Exemplo n.º 1
0
def extract_type(a, ufl_type):
    """Return the set of every node of class ``ufl_type`` occurring in ``a``.

    ``a`` may be a Form, an Integral or an Expr. When ``ufl_type`` is a
    Terminal subclass only the terminals of each expression are walked;
    otherwise every node is visited via ``post_traversal``.
    """
    if issubclass(ufl_type, Terminal):
        # Only leaves can be Terminals, so a terminal-only walk suffices.
        walk = traverse_terminals
    else:
        walk = post_traversal
    found = set()
    for expression in iter_expressions(a):
        for node in walk(expression):
            if isinstance(node, ufl_type):
                found.add(node)
    return found
Exemplo n.º 2
0
def extract_type(a, ufl_type):
    """Collect all objects of class ``ufl_type`` found in ``a`` into a set.

    ``a`` can be a Form, Integral or Expr. Terminal classes are searched
    with ``traverse_terminals``; any other class with ``post_traversal``.
    """
    terminal_only = issubclass(ufl_type, Terminal)
    walker = traverse_terminals if terminal_only else post_traversal
    return set(node
               for expr in iter_expressions(a)
               for node in walker(expr)
               if isinstance(node, ufl_type))
Exemplo n.º 3
0
def extract_type(a, ufl_type):
    """Return a set of all objects of class ``ufl_type`` found in ``a``.

    ``a`` can be a Form, Integral or Expr. Nodes are visited with the
    ``unique_*`` traversals; the result is a set in either case.
    """
    # Optimization: a Terminal type can only match leaves, so walking the
    # terminals alone is enough.
    if issubclass(ufl_type, Terminal):
        visit = traverse_unique_terminals
    else:
        visit = unique_pre_traversal
    return set(node for expr in iter_expressions(a) for node in visit(expr)
               if isinstance(node, ufl_type))
Exemplo n.º 4
0
def has_type(a, ufl_types):
    """Check if any class from ufl_types is found in a.

    :param a: a Form, Integral or Expr.
    :param ufl_types: a single Expr subclass, or a tuple (or other iterable)
        of classes.
    :returns: True if at least one node of ``a`` is an instance of one of
        the given classes, False otherwise.
    """
    # Wrap a bare class in a tuple. The isinstance(..., type) guard matters:
    # issubclass() raises TypeError when its first argument is a tuple, which
    # previously made the multi-class form of this function unusable.
    if isinstance(ufl_types, type) and issubclass(ufl_types, Expr):
        ufl_types = (ufl_types, )
    else:
        # isinstance() below only accepts a class or a tuple of classes.
        ufl_types = tuple(ufl_types)
    if all(issubclass(ufl_type, Terminal) for ufl_type in ufl_types):
        # Optimization: only terminals can match, so skip operator nodes.
        traversal = traverse_terminals
    else:
        traversal = post_traversal
    return any(isinstance(o, ufl_types)
               for e in iter_expressions(a)
               for o in traversal(e))
Exemplo n.º 5
0
def has_type(a, ufl_types):
    """Check if any class from ufl_types is found in a.

    :param a: a Form, Integral or Expr.
    :param ufl_types: a single Expr subclass, or a tuple (or other iterable)
        of classes.
    :returns: True if some node of ``a`` is an instance of one of the given
        classes, False otherwise.
    """
    # issubclass() raises TypeError for a tuple argument, so only call it on
    # an actual class; any other input is normalised to a tuple.
    if isinstance(ufl_types, type) and issubclass(ufl_types, Expr):
        ufl_types = (ufl_types,)
    else:
        ufl_types = tuple(ufl_types)
    if all(issubclass(ufl_type, Terminal) for ufl_type in ufl_types):
        # Optimization: only terminals can match, so skip operator nodes.
        traversal = traverse_terminals
    else:
        traversal = post_traversal
    return any(isinstance(o, ufl_types)
               for e in iter_expressions(a)
               for o in traversal(e))
Exemplo n.º 6
0
def _extract_arguments(form):
    """Return the Argument terminals of ``form`` as a list.

    This is a copy of extract_type in ufl.algorithms.analysis that keeps the
    traversal order in a list instead of wrapping the result in a set; no
    de-duplication is performed across different expressions.
    """
    found = []
    for expression in iter_expressions(form):
        for terminal in traverse_unique_terminals(expression):
            if isinstance(terminal, Argument):
                found.append(terminal)
    return found
Exemplo n.º 7
0
def has_exact_type(a, ufl_type):
    """Return whether a node whose type code equals that of ``ufl_type`` occurs in ``a``.

    The comparison uses ``_ufl_typecode_``, so subclasses of ``ufl_type`` with
    a different type code do not match. ``a`` can be a Form, Integral or Expr.
    """
    wanted = ufl_type._ufl_typecode_
    # Optimization: Terminal types can only appear as leaves.
    traversal = (traverse_unique_terminals if issubclass(ufl_type, Terminal)
                 else unique_pre_traversal)
    for expr in iter_expressions(a):
        for node in traversal(expr):
            if node._ufl_typecode_ == wanted:
                return True
    return False
Exemplo n.º 8
0
def form_iterator(form, iterator_type="nodes"):
    """Lazily iterate over the pieces of ``form``.

    :param form: an object exposing ``integrals()``.
    :param iterator_type: ``"nodes"`` to yield every node of every expression
        of every integral (in ``pre_traversal`` order), or ``"integrals"`` to
        yield the integrals themselves.
    :raises ValueError: on the first iteration step if ``iterator_type`` is
        not one of the supported values (this is a generator, so nothing is
        checked until iteration starts).
    """
    if iterator_type == "nodes":
        for integral in form.integrals():
            for expression in iter_expressions(integral):
                # Original note: "pre_traversal algorithms guarantees that
                # subsolutions are processed before solutions".
                # NOTE(review): confirm this against pre_traversal's visiting order.
                for node in pre_traversal(expression):
                    yield node
    elif iterator_type == "integrals":
        for integral in form.integrals():
            yield integral
    else:
        # An explicit exception instead of an assert: it survives `python -O`
        # and gives callers a consistent ValueError.
        raise ValueError("Invalid iterator type")
Exemplo n.º 9
0
def _extract_variables(a):
    """Return a list of the Variable objects in ``a`` (a Form, Integral or Expr).

    Each label is kept only once (first occurrence wins). Because nodes are
    visited with ``unique_post_traversal``, the list obeys dependency order.
    """
    seen_labels = set()
    ordered = []
    for expression in iter_expressions(a):
        for node in unique_post_traversal(expression):
            if not isinstance(node, Variable):
                continue
            # A Variable's operands are (expression, label); only the label is needed.
            _, label = node.ufl_operands
            if label in seen_labels:
                continue
            ordered.append(node)
            seen_labels.add(label)
    return ordered
Exemplo n.º 10
0
def is_multilinear(form):
    """Check if form is multilinear in arguments.

    Runs extract_argument_dependencies on every expression of ``form`` and
    logs how many argument configurations were found for each.

    :returns: True if no expression raised NotMultiLinearException, False
        otherwise (a warning naming the offending term is logged then).
    """
    # An attempt at implementing is_multilinear using extract_argument_dependencies.
    # TODO: This has some false negatives for "multiple configurations". (Does it still? Needs testing!)
    # TODO: FFC probably needs a variant of this which checks for some sorts of linearity
    #       in Coefficients as well, this should be a fairly simple extension of the current algorithm.
    try:
        for e in iter_expressions(form):
            deps = extract_argument_dependencies(e)
            nargs = [len(d) for d in deps]
            if len(nargs) == 0:
                debug("This form is a functional.")
            elif len(nargs) == 1:
                debug("This form is linear in %d arguments." % nargs[0])
            else:
                warning("This form has more than one argument "
                    "'configuration', it has terms that are linear in %s "
                    "arguments respectively." % str(nargs))
    # 'as' syntax works on Python 2.6+ and 3; the old comma form is a
    # SyntaxError on Python 3.
    except NotMultiLinearException as msg:
        warning("Form is not multilinear, the offending term is: %s" % msg)
        return False
    # Previously fell off the end and returned None, so callers testing
    # `if not is_multilinear(form)` treated every multilinear form as a failure.
    return True
Exemplo n.º 11
0
def is_multilinear(form):
    """Check if form is multilinear in arguments.

    Runs extract_argument_dependencies on every expression of ``form`` and
    logs how many argument configurations were found for each.

    :returns: True if no expression raised NotMultiLinearException, False
        otherwise (a warning naming the offending term is logged then).
    """
    # An attempt at implementing is_multilinear using extract_argument_dependencies.
    # TODO: This has some false negatives for "multiple configurations". (Does it still? Needs testing!)
    # TODO: FFC probably needs a variant of this which checks for some sorts of linearity
    #       in Coefficients as well, this should be a fairly simple extension of the current algorithm.
    try:
        for e in iter_expressions(form):
            deps = extract_argument_dependencies(e)
            nargs = [len(d) for d in deps]
            if len(nargs) == 0:
                debug("This form is a functional.")
            elif len(nargs) == 1:
                debug("This form is linear in %d arguments." % nargs[0])
            else:
                warning("This form has more than one argument "
                    "'configuration', it has terms that are linear in %s "
                    "arguments respectively." % str(nargs))
    # 'as' syntax works on Python 2.6+ and 3; the old comma form is a
    # SyntaxError on Python 3.
    except NotMultiLinearException as msg:
        warning("Form is not multilinear, the offending term is: %s" % msg)
        return False
    # Previously fell off the end and returned None, so callers testing
    # `if not is_multilinear(form)` treated every multilinear form as a failure.
    return True
Exemplo n.º 12
0
def extract_terminals(a):
    """Return the set of all Terminal objects in ``a`` (a Form, Integral or Expr)."""
    terminals = set()
    for expression in iter_expressions(a):
        terminals.update(node for node in post_traversal(expression)
                         if isinstance(node, Terminal))
    return terminals
Exemplo n.º 13
0
def __unused__extract_classes(a):
    """Return the set of distinct ``_ufl_class_`` values of the nodes in ``a``.

    ``a`` can be a Form, Integral or Expr; nodes are visited with
    ``unique_pre_traversal``.
    """
    classes = set()
    for expression in iter_expressions(a):
        for node in unique_pre_traversal(expression):
            classes.add(node._ufl_class_)
    return classes
Exemplo n.º 14
0
def validate_form(form):
    """Perform all implemented validations on a form.

    Problems are collected while checking and then reported together: fatal
    ones through a single call to ``error`` (expected to raise), non-fatal
    ones through ``warning``.

    Checks performed:
    - ``form`` is a Form (reported immediately through ``error``);
    - at least one domain is referenced by the form's terminals;
    - all top domains share a single cell;
    - no two distinct Coefficient (or Argument) instances share a count;
    - every integrand is a true scalar;
    - restrictions appear only where permitted (via ``check_restrictions``).
    """
    # TODO: Can we make this return a list of errors instead of raising exception?
    errors = []
    warnings = []

    if not isinstance(form, Form):
        msg = "Validation failed, not a Form:\n%s" % repr(form)
        error(msg)
        #errors.append(msg)
        #return errors

    # FIXME: Add back check for multilinearity
    # Check that form is multilinear
    #if not is_multilinear(form):
    #    errors.append("Form is not multilinear in arguments.")

    # FIXME DOMAIN: Add check for consistency between domains somehow
    domains = set(t.domain()
                  for e in iter_expressions(form)
                  for t in traverse_terminals(e)) - set((None,))
    if not domains:
        # Reported once; previously an empty top-domain set repeated this error.
        errors.append("Missing domain definition in form.")

    # None was already removed from `domains`, so no extra filter is needed.
    top_domains = set(dom.top_domain() for dom in domains)
    if len(top_domains) > 1:
        warnings.append("Multiple top domain definitions in form: %s" % str(top_domains))

    # Check that cell is the same everywhere
    cells = set(dom.cell() for dom in top_domains) - set((None,))
    if not cells:
        errors.append("Missing cell definition in form.")
    elif len(cells) > 1:
        errors.append("Multiple cell definitions in form: %s" % str(cells))

    # Check that no Coefficient or Argument instance
    # have the same count unless they are the same
    coefficients = {}
    arguments = {}
    for e in iter_expressions(form):
        for f in traverse_terminals(e):
            if isinstance(f, Coefficient):
                c = f.count()
                if c in coefficients:
                    g = coefficients[c]
                    if f is not g:
                        errors.append("Found different Coefficients with " +
                                      "same count: %s and %s." % (repr(f), repr(g)))
                else:
                    coefficients[c] = f

            elif isinstance(f, Argument):
                c = f.count()
                if c in arguments:
                    g = arguments[c]
                    if f is not g:
                        # Counts -2 and -1 are given specific names in the message.
                        if c == -2:
                            msg = "TestFunctions"
                        elif c == -1:
                            msg = "TrialFunctions"
                        else:
                            msg = "Arguments with same count"
                        msg = "Found different %s: %s and %s." % (msg, repr(f), repr(g))
                        errors.append(msg)
                else:
                    arguments[c] = f

    # Check that all integrands are scalar
    for expression in iter_expressions(form):
        if not is_true_ufl_scalar(expression):
            errors.append("Found non-scalar integrand expression:\n%s\n%s" %
                          (str(expression), repr(expression)))

    # Check that restrictions are permissible
    for integral in form.integrals():
        # Only allow restrictions on interior facet integrals and surface measures
        restrictions_allowed = integral.measure().domain_type() in (Measure.INTERIOR_FACET,
                                                                    Measure.SURFACE)
        check_restrictions(integral.integrand(), restrictions_allowed)

    # Report non-fatal issues; they used to be collected and then silently dropped.
    for msg in warnings:
        warning(msg)

    # Raise exception with all error messages
    # TODO: Return errors list instead, need to collect messages from all validations above first.
    if errors:
        final_msg = 'Found errors in validation of form:\n%s' % '\n\n'.join(errors)
        error(final_msg)
Exemplo n.º 15
0
        def separate(self):
            class _SeparatedParametrizedForm_Replacer(Transformer):
                def __init__(self, mapping):
                    Transformer.__init__(self)
                    self.mapping = mapping

                def operator(self, e, *ops):
                    if e in self.mapping:
                        return self.mapping[e]
                    else:
                        return e._ufl_expr_reconstruct_(*ops)

                def terminal(self, e):
                    return self.mapping.get(e, e)

            logger.log(DEBUG,
                       "***        SEPARATE FORM COEFFICIENTS        ***")

            logger.log(DEBUG, "1. Extract coefficients")
            integral_to_coefficients = dict()
            for integral in self._form.integrals():
                logger.log(
                    DEBUG,
                    "\t Currently on integrand " + str(integral.integrand()))
                self._coefficients.append(list())  # of ParametrizedExpression
                for e in iter_expressions(integral):
                    logger.log(DEBUG, "\t\t Expression " + str(e))
                    pre_traversal_e = [n for n in pre_traversal(e)]
                    tree_nodes_skip = [False for _ in pre_traversal_e]
                    for (n_i, n) in enumerate(pre_traversal_e):
                        if not tree_nodes_skip[n_i]:
                            # Skip expressions which are trivially non parametrized
                            if isinstance(n, Argument):
                                logger.log(
                                    DEBUG, "\t\t Node " + str(n) +
                                    " is skipped because it is an Argument")
                                continue
                            elif isinstance(n, Constant):
                                logger.log(
                                    DEBUG, "\t\t Node " + str(n) +
                                    " is skipped because it is a Constant")
                                continue
                            elif isinstance(n, MultiIndex):
                                logger.log(
                                    DEBUG, "\t\t Node " + str(n) +
                                    " is skipped because it is a MultiIndex")
                                continue
                            # Skip all expressions with at least one leaf which is an Argument
                            for t in traverse_terminals(n):
                                if isinstance(t, Argument):
                                    logger.log(
                                        DEBUG, "\t\t Node " + str(n) +
                                        " is skipped because it contains an Argument"
                                    )
                                    break
                            else:  # not broken
                                logger.log(
                                    DEBUG, "\t\t Node " + str(n) +
                                    " and its descendants are being analyzed for non-parametrized check"
                                )
                                # Make sure to skip all descendants of this node in the outer loop
                                # Note that a map with key set to the expression is not enough to
                                # mark the node as visited, since the same expression may appear
                                # on different sides of the tree
                                pre_traversal_n = [d for d in pre_traversal(n)]
                                for (d_i, d) in enumerate(pre_traversal_n):
                                    assert d == pre_traversal_e[
                                        n_i +
                                        d_i]  # make sure that we are marking the right node
                                    tree_nodes_skip[n_i + d_i] = True
                                # We might be able to strip any (non-parametrized) expression out
                                all_candidates = list()
                                internal_tree_nodes_skip = [
                                    False for _ in pre_traversal_n
                                ]
                                for (d_i, d) in enumerate(pre_traversal_n):
                                    if not internal_tree_nodes_skip[d_i]:
                                        # Skip all expressions where at least one leaf is not parametrized
                                        for t in traverse_terminals(d):
                                            if isinstance(t, BaseExpression):
                                                if wrapping.is_pull_back_expression(
                                                        t
                                                ) and not wrapping.is_pull_back_expression_parametrized(
                                                        t):
                                                    logger.log(
                                                        DEBUG,
                                                        "\t\t\t Descendant node "
                                                        + str(d) +
                                                        " causes the non-parametrized check to break because it contains a non-parametrized pulled back expression"
                                                    )
                                                    break
                                                else:
                                                    parameters = t._parameters
                                                    if "mu_0" not in parameters:
                                                        logger.log(
                                                            DEBUG,
                                                            "\t\t\t Descendant node "
                                                            + str(d) +
                                                            " causes the non-parametrized check to break because it contains a non-parametrized expression"
                                                        )
                                                        break
                                            elif isinstance(t, Constant):
                                                logger.log(
                                                    DEBUG,
                                                    "\t\t\t Descendant node " +
                                                    str(d) +
                                                    " causes the non-parametrized check to break because it contains a constant"
                                                )
                                                break
                                            elif isinstance(
                                                    t, GeometricQuantity
                                            ) and not isinstance(
                                                    t, FacetNormal
                                            ) and self._strict:
                                                logger.log(
                                                    DEBUG,
                                                    "\t\t\t Descendant node " +
                                                    str(d) +
                                                    " causes the non-parametrized check to break because it contains a geometric quantity and strict mode is on"
                                                )
                                                break
                                            elif wrapping.is_problem_solution_type(
                                                    t):
                                                if not wrapping.is_problem_solution(
                                                        t
                                                ) and not wrapping.is_problem_solution_dot(
                                                        t):
                                                    logger.log(
                                                        DEBUG,
                                                        "\t\t\t Descendant node "
                                                        + str(d) +
                                                        " causes the non-parametrized check to break because it contains a non-parametrized function"
                                                    )
                                                    break
                                                elif self._strict:  # solutions are not allowed, break
                                                    if wrapping.is_problem_solution(
                                                            t):
                                                        (
                                                            _, component,
                                                            solution
                                                        ) = wrapping.solution_identify_component(
                                                            t)
                                                        problem = get_problem_from_solution(
                                                            solution)
                                                        logger.log(
                                                            DEBUG,
                                                            "\t\t\t Descendant node "
                                                            + str(d) +
                                                            " causes the non-parametrized check to break because it contains the solution of "
                                                            + problem.name() +
                                                            " (exact problem decorator: "
                                                            + str(
                                                                hasattr(
                                                                    problem,
                                                                    "__is_exact__"
                                                                )) +
                                                            ", component: " +
                                                            str(component) +
                                                            ") and strict mode is on"
                                                        )
                                                        break
                                                    elif wrapping.is_problem_solution_dot(
                                                            t):
                                                        (
                                                            _, component,
                                                            solution_dot
                                                        ) = wrapping.solution_dot_identify_component(
                                                            t)
                                                        problem = get_problem_from_solution_dot(
                                                            solution_dot)
                                                        logger.log(
                                                            DEBUG,
                                                            "\t\t\t Descendant node "
                                                            + str(d) +
                                                            " causes the non-parametrized check to break because it contains the solution_dot of "
                                                            + problem.name() +
                                                            " (exact problem decorator: "
                                                            + str(
                                                                hasattr(
                                                                    problem,
                                                                    "__is_exact__"
                                                                )) +
                                                            ", component: " +
                                                            str(component) +
                                                            ") and strict mode is on"
                                                        )
                                                    else:
                                                        raise RuntimeError(
                                                            "Unidentified solution found"
                                                        )
                                        else:
                                            at_least_one_expression_or_solution = False
                                            for t in traverse_terminals(d):
                                                if isinstance(
                                                        t, BaseExpression
                                                ):  # which is parametrized, because previous for loop was not broken
                                                    at_least_one_expression_or_solution = True
                                                    logger.log(
                                                        DEBUG,
                                                        "\t\t\t Descendant node "
                                                        + str(d) +
                                                        " is a candidate after non-parametrized check because it contains the parametrized expression "
                                                        + str(t))
                                                    break
                                                elif wrapping.is_problem_solution_type(
                                                        t):
                                                    if wrapping.is_problem_solution(
                                                            t):
                                                        at_least_one_expression_or_solution = True
                                                        (
                                                            _, component,
                                                            solution
                                                        ) = wrapping.solution_identify_component(
                                                            t)
                                                        problem = get_problem_from_solution(
                                                            solution)
                                                        logger.log(
                                                            DEBUG,
                                                            "\t\t\t Descendant node "
                                                            + str(d) +
                                                            " is a candidate after non-parametrized check because it contains the solution of "
                                                            + problem.name() +
                                                            " (exact problem decorator: "
                                                            + str(
                                                                hasattr(
                                                                    problem,
                                                                    "__is_exact__"
                                                                )) +
                                                            ", component: " +
                                                            str(component) +
                                                            ")")
                                                        break
                                                    elif wrapping.is_problem_solution_dot(
                                                            t):
                                                        at_least_one_expression_or_solution = True
                                                        (
                                                            _, component,
                                                            solution_dot
                                                        ) = wrapping.solution_dot_identify_component(
                                                            t)
                                                        problem = get_problem_from_solution_dot(
                                                            solution_dot)
                                                        logger.log(
                                                            DEBUG,
                                                            "\t\t\t Descendant node "
                                                            + str(d) +
                                                            " is a candidate after non-parametrized check because it contains the solution_dot of "
                                                            + problem.name() +
                                                            " (exact problem decorator: "
                                                            + str(
                                                                hasattr(
                                                                    problem,
                                                                    "__is_exact__"
                                                                )) +
                                                            ", component: " +
                                                            str(component) +
                                                            ")")
                                                        break
                                            if at_least_one_expression_or_solution:
                                                all_candidates.append(d)
                                                pre_traversal_d = [
                                                    q for q in pre_traversal(d)
                                                ]
                                                for (q_i, q) in enumerate(
                                                        pre_traversal_d):
                                                    assert q == pre_traversal_n[
                                                        d_i +
                                                        q_i]  # make sure that we are marking the right node
                                                    internal_tree_nodes_skip[
                                                        d_i + q_i] = True
                                            else:
                                                logger.log(
                                                    DEBUG,
                                                    "\t\t\t Descendant node " +
                                                    str(d) +
                                                    " has not passed the non-parametrized because it is not a parametrized expression or a solution"
                                                )
                                # Evaluate candidates
                                if len(
                                        all_candidates
                                ) == 0:  # the whole expression was actually non-parametrized
                                    logger.log(
                                        DEBUG, "\t\t Node " + str(n) +
                                        " is skipped because it is a non-parametrized coefficient"
                                    )
                                    continue
                                elif len(
                                        all_candidates
                                ) == 1:  # the whole expression was actually parametrized
                                    logger.log(
                                        DEBUG, "\t\t Node " + str(n) +
                                        " will be accepted because it is a non-parametrized coefficient"
                                    )
                                    pass
                                else:  # part of the expression was not parametrized, and separating the non parametrized part may result in more than one coefficient
                                    if self._strict:  # non parametrized coefficients are not allowed, so split the expression
                                        logger.log(
                                            DEBUG, "\t\t\t Node " + str(n) +
                                            " will be accepted because it is a non-parametrized coefficient with more than one candidate. It will be split because strict mode is on. Its split coefficients are "
                                            + ", ".join([
                                                str(c) for c in all_candidates
                                            ]))
                                    else:  # non parametrized coefficients are allowed, so go on with the whole expression
                                        logger.log(
                                            DEBUG, "\t\t\t Node " + str(n) +
                                            " will be accepted because it is a non-parametrized coefficient with more than one candidate. It will not be split because strict mode is off. Splitting it would have resulted in more than one coefficient, namely "
                                            + ", ".join([
                                                str(c) for c in all_candidates
                                            ]))
                                        all_candidates = [n]
                                # Add the coefficient(s)
                                for candidate in all_candidates:

                                    def preprocess_candidate(candidate):
                                        if isinstance(candidate, Indexed):
                                            assert len(
                                                candidate.ufl_operands) == 2
                                            assert isinstance(
                                                candidate.ufl_operands[1],
                                                MultiIndex)
                                            if all([
                                                    isinstance(
                                                        index, FixedIndex)
                                                    for index in candidate.
                                                    ufl_operands[1].indices()
                                            ]):
                                                logger.log(
                                                    DEBUG,
                                                    "\t\t\t Preprocessed descendant node "
                                                    + str(candidate) +
                                                    " as an Indexed expression with fixed indices, resulting in a candidate "
                                                    + str(candidate) +
                                                    " of type " +
                                                    str(type(candidate)))
                                                return candidate  # no further preprocessing needed
                                            else:
                                                logger.log(
                                                    DEBUG,
                                                    "\t\t\t Preprocessed descendant node "
                                                    + str(candidate) +
                                                    " as an Indexed expression with at least one mute index, resulting in a candidate "
                                                    + str(candidate.
                                                          ufl_operands[0]) +
                                                    " of type " + str(
                                                        type(candidate.
                                                             ufl_operands[0])))
                                                return preprocess_candidate(
                                                    candidate.ufl_operands[0])
                                        elif isinstance(candidate, IndexSum):
                                            assert len(
                                                candidate.ufl_operands) == 2
                                            assert isinstance(
                                                candidate.ufl_operands[1],
                                                MultiIndex)
                                            assert all([
                                                isinstance(index, MuteIndex)
                                                for index in candidate.
                                                ufl_operands[1].indices()
                                            ])
                                            logger.log(
                                                DEBUG,
                                                "\t\t\t Preprocessed descendant node "
                                                + str(candidate) +
                                                " as an IndexSum expression, resulting in a candidate "
                                                +
                                                str(candidate.ufl_operands[0])
                                                + " of type " + str(
                                                    type(candidate.
                                                         ufl_operands[0])))
                                            return preprocess_candidate(
                                                candidate.ufl_operands[0])
                                        elif isinstance(candidate, ListTensor):
                                            candidates = set([
                                                preprocess_candidate(component)
                                                for component in
                                                candidate.ufl_operands
                                            ])
                                            if len(candidates) == 1:
                                                preprocessed_candidate = candidates.pop(
                                                )
                                                logger.log(
                                                    DEBUG,
                                                    "\t\t\t Preprocessed descendant node "
                                                    + str(candidate) +
                                                    " as an ListTensor expression with a unique preprocessed component, resulting in a candidate "
                                                    +
                                                    str(preprocessed_candidate)
                                                    + " of type " + str(
                                                        type(
                                                            preprocessed_candidate
                                                        )))
                                                return preprocess_candidate(
                                                    preprocessed_candidate)
                                            else:
                                                at_least_one_mute_index = False
                                                candidates_from_components = list(
                                                )
                                                for component in candidates:
                                                    assert isinstance(
                                                        component,
                                                        (ComponentTensor,
                                                         Indexed))
                                                    assert len(
                                                        component.ufl_operands
                                                    ) == 2
                                                    assert isinstance(
                                                        component.
                                                        ufl_operands[1],
                                                        MultiIndex)
                                                    if not all([
                                                            isinstance(
                                                                index,
                                                                FixedIndex) for
                                                            index in component.
                                                            ufl_operands[1].
                                                            indices()
                                                    ]):
                                                        at_least_one_mute_index = True
                                                    candidates_from_components.append(
                                                        preprocess_candidate(
                                                            component.
                                                            ufl_operands[0]))
                                                if at_least_one_mute_index:
                                                    candidates_from_components = set(
                                                        candidates_from_components
                                                    )
                                                    assert len(
                                                        candidates_from_components
                                                    ) == 1
                                                    preprocessed_candidate = candidates_from_components.pop(
                                                    )
                                                    logger.log(
                                                        DEBUG,
                                                        "\t\t\t Preprocessed descendant node "
                                                        + str(candidate) +
                                                        " as an ListTensor expression with multiple preprocessed components with at least one mute index, resulting in a candidate "
                                                        +
                                                        str(preprocessed_candidate
                                                            ) + " of type " +
                                                        str(
                                                            type(
                                                                preprocessed_candidate
                                                            )))
                                                    return preprocess_candidate(
                                                        preprocessed_candidate)
                                                else:
                                                    logger.log(
                                                        DEBUG,
                                                        "\t\t\t Preprocessed descendant node "
                                                        + str(candidate) +
                                                        " as an ListTensor expression with multiple preprocessed components with fixed indices, resulting in a candidate "
                                                        + str(candidate) +
                                                        " of type " +
                                                        str(type(candidate)))
                                                    return candidate  # no further preprocessing needed
                                        else:
                                            logger.log(
                                                DEBUG,
                                                "\t\t\t No preprocessing required for descendant node "
                                                + str(candidate) +
                                                " as a coefficient of type " +
                                                str(type(candidate)))
                                            return candidate

                                    preprocessed_candidate = preprocess_candidate(
                                        candidate)
                                    if preprocessed_candidate not in self._coefficients[
                                            -1]:
                                        self._coefficients[-1].append(
                                            preprocessed_candidate)
                                    logger.log(
                                        DEBUG,
                                        "\t\t\t Accepting descendant node " +
                                        str(preprocessed_candidate) +
                                        " as a coefficient of type " +
                                        str(type(preprocessed_candidate)))
                        else:
                            logger.log(
                                DEBUG, "\t\t Node " + str(n) +
                                " to be skipped because it is a descendant of a coefficient which has already been detected"
                            )
                if len(self._coefficients[-1]
                       ) == 0:  # then there were no coefficients to extract
                    logger.log(DEBUG,
                               "\t There were no coefficients to extract")
                    self._coefficients.pop(
                    )  # remove the (empty) element that was added to possibly store coefficients
                else:
                    logger.log(DEBUG, "\t Extracted coefficients are:")
                    for c in self._coefficients[-1]:
                        logger.log(DEBUG, "\t\t" + str(c))
                    integral_to_coefficients[integral] = self._coefficients[-1]

            logger.log(DEBUG,
                       "2. Prepare placeholders and forms with placeholders")
            for integral in self._form.integrals():
                # Prepare measure for the new form (from firedrake/mg/ufl_utils.py)
                measure = Measure(integral.integral_type(),
                                  domain=integral.ufl_domain(),
                                  subdomain_id=integral.subdomain_id(),
                                  subdomain_data=integral.subdomain_data(),
                                  metadata=integral.metadata())
                if integral not in integral_to_coefficients:
                    logger.log(
                        DEBUG, "\t Adding form for integrand " +
                        str(integral.integrand()) + " to unchanged forms")
                    self._form_unchanged.append(integral.integrand() * measure)
                else:
                    logger.log(
                        DEBUG,
                        "\t Preparing form with placeholders for integrand " +
                        str(integral.integrand()))
                    self._placeholders.append(list())  # of Constants
                    placeholders_dict = dict()
                    for c in integral_to_coefficients[integral]:
                        self._placeholders[-1].append(
                            Constant(self._NaN * ones(c.ufl_shape)))
                        placeholders_dict[c] = self._placeholders[-1][-1]
                        logger.log(
                            DEBUG, "\t\t " + str(placeholders_dict[c]) +
                            " is the placeholder for " + str(c))
                    replacer = _SeparatedParametrizedForm_Replacer(
                        placeholders_dict)
                    new_integrand = apply_transformer(integral.integrand(),
                                                      replacer)
                    self._form_with_placeholders.append(new_integrand *
                                                        measure)

            logger.log(
                DEBUG,
                "3. Assert that there are no parametrized expressions left")
            for form in self._form_with_placeholders:
                for integral in form.integrals():
                    for e in pre_traversal(integral.integrand()):
                        if isinstance(e, BaseExpression):
                            assert not (
                                wrapping.is_pull_back_expression(e)
                                and wrapping.
                                is_pull_back_expression_parametrized(e)
                            ), "Form " + str(
                                integral
                            ) + " still contains a parametrized pull back expression"
                            parameters = e._parameters
                            assert "mu_0" not in parameters, "Form " + str(
                                integral
                            ) + " still contains a parametrized expression"

            logger.log(DEBUG, "4. Prepare coefficients hash codes")
            for addend in self._coefficients:
                self._placeholder_names.append(list())  # of string
                for factor in addend:
                    self._placeholder_names[-1].append(
                        wrapping.expression_name(factor))

            logger.log(DEBUG, "5. Assert list length consistency")
            assert len(self._coefficients) == len(self._placeholders)
            assert len(self._coefficients) == len(self._placeholder_names)
            for (c, p, pn) in zip(self._coefficients, self._placeholders,
                                  self._placeholder_names):
                assert len(c) == len(p)
                assert len(c) == len(pn)
            assert len(self._coefficients) == len(self._form_with_placeholders)

            logger.log(DEBUG,
                       "*** DONE - SEPARATE FORM COEFFICIENTS - DONE ***")
            logger.log(DEBUG, "")
Exemplo n.º 16
0
def expression_iterator(expression):
    """Yield every node of every sub-expression of ``expression``.

    Sub-expressions come from ``iter_expressions(expression)``; each of them is
    walked with ``pre_traversal`` and its nodes are yielded in exactly the
    order that ``pre_traversal`` produces them.
    """
    for root in iter_expressions(expression):
        # Original author's note: pre_traversal algorithms guarantee that
        # subsolutions are processed before solutions.
        yield from pre_traversal(root)
Exemplo n.º 17
0
def validate_form(form):
    """Perform all implemented validations on ``form``.

    If ``form`` is not a ``Form``, ``error`` is called immediately. The other
    checks collect their messages, and any messages found are reported
    together through one final call to ``error``:

    * the form's terminals must reference at least one domain;
    * those domains must share exactly one cell;
    * distinct Coefficients must not share a count, and distinct Arguments
      must not share a (number, part) pair;
    * every integrand expression must satisfy ``is_true_ufl_scalar``;
    * ``check_restrictions`` is run on each integrand, with its flag set to
      True only for integral types starting with "interior_facet".

    TODO: Can we make this return a list of errors instead of raising
    exception?
    """
    problems = []

    if not isinstance(form, Form):
        error("Validation failed, not a Form:\n%s" % ufl_err_str(form))
        # TODO: append this message to ``problems`` and return them instead.

    # FIXME: There's a bunch of other checks we should do here.
    # FIXME: Add back check for multilinearity (a disabled version tested
    #        ``is_multilinear(form)``).

    # FIXME DOMAIN: Add check for consistency between domains somehow
    domains = {terminal.ufl_domain()
               for integrand in iter_expressions(form)
               for terminal in traverse_unique_terminals(integrand)} - {None}
    if not domains:
        problems.append("Missing domain definition in form.")

    # Every domain must use the same cell.
    cells = {domain.ufl_cell() for domain in domains} - {None}
    if not cells:
        problems.append("Missing cell definition in form.")
    elif len(cells) > 1:
        problems.append("Multiple cell definitions in form: %s" % str(cells))

    # Distinct Coefficient instances must not share a count; distinct
    # Argument instances must not share a (number, part) pair.
    coefficient_by_count = {}
    argument_by_key = {}
    for integrand in iter_expressions(form):
        for terminal in traverse_unique_terminals(integrand):
            if isinstance(terminal, Coefficient):
                first = coefficient_by_count.setdefault(terminal.count(), terminal)
                if first is not terminal:
                    problems.append(
                        "Found different Coefficients with same count: %s and %s."
                        % (repr(terminal), repr(first)))
            elif isinstance(terminal, Argument):
                number = terminal.number()
                key = (number, terminal.part())
                first = argument_by_key.setdefault(key, terminal)
                if first is not terminal:
                    if number == 0:
                        what = "TestFunctions"
                    elif number == 1:
                        what = "TrialFunctions"
                    else:
                        what = "Arguments with same number and part"
                    problems.append("Found different %s: %s and %s."
                                    % (what, repr(terminal), repr(first)))

    # Every integrand must be scalar.
    problems.extend("Found non-scalar integrand expression: %s\n" % ufl_err_str(integrand)
                    for integrand in iter_expressions(form)
                    if not is_true_ufl_scalar(integrand))

    # Restrictions are only allowed on interior facet integrals and surface
    # measures.
    for integral in form.integrals():
        allow_restrictions = integral.integral_type().startswith("interior_facet")
        check_restrictions(integral.integrand(), allow_restrictions)

    # TODO: Return the list of problems instead, once every validation above
    # collects its messages.
    if problems:
        error('Found errors in validation of form:\n%s' % '\n\n'.join(problems))
Exemplo n.º 18
0
def extract_terminals(a):
    """Return the set of Terminal objects found in ``a``.

    Each expression yielded by ``iter_expressions(a)`` is walked with
    ``post_traversal``, and only nodes that are ``Terminal`` instances are kept.
    """
    terminals = set()
    for expr in iter_expressions(a):
        terminals.update(node for node in post_traversal(expr)
                         if isinstance(node, Terminal))
    return terminals
Exemplo n.º 19
0
Arquivo: checks.py Projeto: FEniCS/ufl
def validate_form(form):  # TODO: Can we make this return a list of errors instead of raising exception?
    """Validate ``form`` with every check that is currently implemented.

    A non-``Form`` argument is reported to ``error`` straight away. All other
    checks accumulate messages, which are joined and passed to ``error`` in a
    single call at the end if any were produced.
    """
    messages = []

    if not isinstance(form, Form):
        error("Validation failed, not a Form:\n%s" % ufl_err_str(form))
        # TODO: collect this message in ``messages`` and return it instead.

    # FIXME: There's a bunch of other checks we should do here.
    # FIXME: Add back check for multilinearity (a disabled version tested
    #        ``is_multilinear(form)``).

    # Domains referenced by any terminal of the form (None entries dropped).
    # FIXME DOMAIN: Add check for consistency between domains somehow
    domains = set(terminal.ufl_domain()
                  for expr in iter_expressions(form)
                  for terminal in traverse_unique_terminals(expr)) - {None}
    if not domains:
        messages.append("Missing domain definition in form.")

    # The cell must be the same across all domains.
    cells = set(domain.ufl_cell() for domain in domains) - {None}
    if not cells:
        messages.append("Missing cell definition in form.")
    elif len(cells) > 1:
        messages.append("Multiple cell definitions in form: %s" % str(cells))

    # Two different Coefficient (resp. Argument) objects may not share the
    # same count (resp. number and part).
    first_coefficient = {}
    first_argument = {}
    for expr in iter_expressions(form):
        for terminal in traverse_unique_terminals(expr):
            if isinstance(terminal, Coefficient):
                count = terminal.count()
                if count not in first_coefficient:
                    first_coefficient[count] = terminal
                elif first_coefficient[count] is not terminal:
                    messages.append("Found different Coefficients with same count: %s and %s."
                                    % (repr(terminal), repr(first_coefficient[count])))
            elif isinstance(terminal, Argument):
                number, part = terminal.number(), terminal.part()
                key = (number, part)
                if key not in first_argument:
                    first_argument[key] = terminal
                elif first_argument[key] is not terminal:
                    label = ("TestFunctions" if number == 0 else
                             "TrialFunctions" if number == 1 else
                             "Arguments with same number and part")
                    messages.append("Found different %s: %s and %s."
                                    % (label, repr(terminal), repr(first_argument[key])))

    # Only scalar integrands are accepted.
    for expr in iter_expressions(form):
        if is_true_ufl_scalar(expr):
            continue
        messages.append("Found non-scalar integrand expression: %s\n" % ufl_err_str(expr))

    # Restrictions are permitted only on interior facet integrals and surface
    # measures.
    for integral in form.integrals():
        on_interior_facet = integral.integral_type().startswith("interior_facet")
        check_restrictions(integral.integrand(), on_interior_facet)

    # TODO: Return the messages instead, once all validations above collect
    # their messages.
    if messages:
        error('Found errors in validation of form:\n%s' % '\n\n'.join(messages))
Exemplo n.º 20
0
def extract_classes(a):
    """Return the set of ``_uflclass`` values of every node found in ``a``.

    Per the original docstring these are the unique Expr subclasses used, and
    ``a`` can be a Form, Integral or Expr. Each expression from
    ``iter_expressions(a)`` is walked with ``post_traversal``.
    """
    return {node._uflclass
            for expr in iter_expressions(a)
            for node in post_traversal(expr)}
Exemplo n.º 21
0
def extract_classes(a):
    """Collect the ``_uflclass`` attribute of every node reachable from ``a``.

    The original docstring describes the result as the set of unique Expr
    subclasses used in ``a``, where ``a`` may be a Form, Integral or Expr.
    """
    classes = set()
    for expr in iter_expressions(a):
        for node in post_traversal(expr):
            classes.add(node._uflclass)
    return classes