def extract_type(a, ufl_type):
    """Return the set of all objects in ``a`` that are instances of ``ufl_type``.

    ``a`` may be a Form, an Integral or an Expr. When ``ufl_type`` is a
    Terminal subclass only the terminals of each expression are visited;
    otherwise every node is visited with a post-order traversal.
    """
    if issubclass(ufl_type, Terminal):
        walk = traverse_terminals
    else:
        walk = post_traversal
    found = set()
    for expression in iter_expressions(a):
        for node in walk(expression):
            if isinstance(node, ufl_type):
                found.add(node)
    return found
def extract_type(a, ufl_type):
    """Return the set of all objects in ``a`` that are instances of ``ufl_type``.

    ``a`` may be a Form, an Integral or an Expr. When ``ufl_type`` is a
    Terminal subclass only the terminals of each expression are visited;
    otherwise every node is visited with a post-order traversal.
    """
    if issubclass(ufl_type, Terminal):
        walk = traverse_terminals
    else:
        walk = post_traversal
    found = set()
    for expression in iter_expressions(a):
        for node in walk(expression):
            if isinstance(node, ufl_type):
                found.add(node)
    return found
def extract_type(a, ufl_type):
    """Collect, as a set, every node of ``a`` that is an instance of ``ufl_type``.

    ``a`` may be a Form, an Integral or an Expr.
    """
    # Optimization: a Terminal subclass can only match terminals, so walking
    # the unique terminals suffices; other types need every unique node.
    visit = traverse_unique_terminals if issubclass(ufl_type, Terminal) else unique_pre_traversal
    matches = set()
    for expression in iter_expressions(a):
        matches.update(node for node in visit(expression) if isinstance(node, ufl_type))
    return matches
def has_type(a, ufl_types):
    """Check if any class from ufl_types is found in a.

    The argument a can be a Form, Integral or Expr. ``ufl_types`` may be a
    single class or a tuple (or any other iterable) of classes. Returns True
    as soon as one node of ``a`` is an instance of one of the classes, and
    False otherwise.
    """
    # Normalize to a tuple of classes. The previous ``issubclass(ufl_types, Expr)``
    # test raised TypeError whenever a tuple of classes was passed.
    if isinstance(ufl_types, type):
        ufl_types = (ufl_types,)
    else:
        ufl_types = tuple(ufl_types)
    if all(issubclass(ufl_type, Terminal) for ufl_type in ufl_types):
        # Optimization: only terminals can match, so skip operator nodes.
        traversal = traverse_terminals
    else:
        traversal = post_traversal
    return any(isinstance(o, ufl_types)
               for e in iter_expressions(a)
               for o in traversal(e))
def has_type(a, ufl_types):
    """Check if any class from ufl_types is found in a.

    The argument a can be a Form, Integral or Expr. ``ufl_types`` may be a
    single class or a tuple (or any other iterable) of classes. Returns True
    as soon as one node of ``a`` is an instance of one of the classes, and
    False otherwise.
    """
    # Normalize to a tuple of classes. The previous ``issubclass(ufl_types, Expr)``
    # test raised TypeError whenever a tuple of classes was passed.
    if isinstance(ufl_types, type):
        ufl_types = (ufl_types,)
    else:
        ufl_types = tuple(ufl_types)
    if all(issubclass(ufl_type, Terminal) for ufl_type in ufl_types):
        # Optimization: only terminals can match, so skip operator nodes.
        traversal = traverse_terminals
    else:
        traversal = post_traversal
    return any(isinstance(o, ufl_types)
               for e in iter_expressions(a)
               for o in traversal(e))
def _extract_arguments(form):
    """Return a list of the Argument terminals found in ``form``.

    This mirrors ``extract_type(form, Argument)`` from
    ufl.algorithms.analysis but keeps a list in traversal order instead of a
    set. Terminals are unique within each expression, but no
    de-duplication is done across different expressions.
    """
    arguments = []
    for expression in iter_expressions(form):
        arguments.extend(terminal for terminal in traverse_unique_terminals(expression)
                         if isinstance(terminal, Argument))
    return arguments
def has_exact_type(a, ufl_type):
    """Tell whether some node of ``a`` carries the same ``_ufl_typecode_`` as ``ufl_type``.

    Nodes are compared by ``_ufl_typecode_`` rather than with isinstance.
    ``a`` may be a Form, an Integral or an Expr.
    """
    typecode = ufl_type._ufl_typecode_
    # Optimization: a Terminal class can only appear among the terminals.
    walk = traverse_unique_terminals if issubclass(ufl_type, Terminal) else unique_pre_traversal
    for expression in iter_expressions(a):
        for node in walk(expression):
            if node._ufl_typecode_ == typecode:
                return True
    return False
def form_iterator(form, iterator_type="nodes"):
    """Lazily walk a form either node by node or integral by integral.

    ``iterator_type="nodes"`` yields, for every integral of ``form`` and every
    expression of that integral, the nodes produced by ``pre_traversal``.
    ``iterator_type="integrals"`` yields the integrals themselves.
    Any other value fails the assertion on the first ``next()`` call. The
    ValueError branch is reachable only when assertions are disabled.
    """
    assert iterator_type in ("nodes", "integrals")
    if iterator_type == "integrals":
        for integral in form.integrals():
            yield integral
    elif iterator_type == "nodes":
        for integral in form.integrals():
            for expression in iter_expressions(integral):
                # Original author's note: pre_traversal algorithms guarantee
                # that subsolutions are processed before solutions.
                for node in pre_traversal(expression):
                    yield node
    else:
        raise ValueError("Invalid iterator type")
def _extract_variables(a):
    """List the Variable nodes of ``a`` (a Form, Integral or Expr) in dependency order.

    Each expression is walked with ``unique_post_traversal``, so operands are
    reached before the nodes that use them. Only the first Variable seen
    for each label is kept.
    """
    seen_labels = set()
    result = []
    for expression in iter_expressions(a):
        for node in unique_post_traversal(expression):
            if not isinstance(node, Variable):
                continue
            _, label = node.ufl_operands
            if label in seen_labels:
                continue
            seen_labels.add(label)
            result.append(node)
    return result
def is_multilinear(form):
    """Check if form is multilinear in arguments.

    Calls ``extract_argument_dependencies`` on every expression of ``form``
    and logs how many arguments the terms are linear in.

    Returns False, after logging a warning, if ``extract_argument_dependencies``
    raises ``NotMultiLinearException`` for some expression. Returns True
    otherwise. Previously nothing was returned on success, so callers saw
    None (falsy).
    """
    # An attempt at implementing is_multilinear using extract_argument_dependencies.
    # TODO: This has some false negatives for "multiple configurations". (Does it still? Needs testing!)
    # TODO: FFC probably needs a variant of this which checks for some sorts of linearity
    #       in Coefficients as well, this should be a fairly simple extension of the current algorithm.
    try:
        for e in iter_expressions(form):
            deps = extract_argument_dependencies(e)
            nargs = [len(d) for d in deps]
            if len(nargs) == 0:
                debug("This form is a functional.")
            elif len(nargs) == 1:
                debug("This form is linear in %d arguments." % nargs[0])
            else:
                warning("This form has more than one argument "
                        "'configuration', it has terms that are linear in %s "
                        "arguments respectively." % str(nargs))
    # ``except X as name`` is valid on Python 2.6+ and 3; the old
    # ``except X, name`` form is a SyntaxError on Python 3.
    except NotMultiLinearException as msg:
        warning("Form is not multilinear, the offending term is: %s" % msg)
        return False
    return True
def is_multilinear(form):
    """Check if form is multilinear in arguments.

    Calls ``extract_argument_dependencies`` on every expression of ``form``
    and logs how many arguments the terms are linear in.

    Returns False, after logging a warning, if ``extract_argument_dependencies``
    raises ``NotMultiLinearException`` for some expression. Returns True
    otherwise. Previously nothing was returned on success, so callers saw
    None (falsy).
    """
    # An attempt at implementing is_multilinear using extract_argument_dependencies.
    # TODO: This has some false negatives for "multiple configurations". (Does it still? Needs testing!)
    # TODO: FFC probably needs a variant of this which checks for some sorts of linearity
    #       in Coefficients as well, this should be a fairly simple extension of the current algorithm.
    try:
        for e in iter_expressions(form):
            deps = extract_argument_dependencies(e)
            nargs = [len(d) for d in deps]
            if len(nargs) == 0:
                debug("This form is a functional.")
            elif len(nargs) == 1:
                debug("This form is linear in %d arguments." % nargs[0])
            else:
                warning("This form has more than one argument "
                        "'configuration', it has terms that are linear in %s "
                        "arguments respectively." % str(nargs))
    # ``except X as name`` is valid on Python 2.6+ and 3; the old
    # ``except X, name`` form is a SyntaxError on Python 3.
    except NotMultiLinearException as msg:
        warning("Form is not multilinear, the offending term is: %s" % msg)
        return False
    return True
def extract_terminals(a):
    "Return the set of Terminal nodes reachable from any expression of a."
    terminals = set()
    for expression in iter_expressions(a):
        terminals.update(node for node in post_traversal(expression)
                         if isinstance(node, Terminal))
    return terminals
def __unused__extract_classes(a):
    """Return the set of ``_ufl_class_`` values of every node in ``a``.

    ``a`` may be a Form, an Integral or an Expr.
    """
    classes = set()
    for expression in iter_expressions(a):
        for node in unique_pre_traversal(expression):
            classes.add(node._ufl_class_)
    return classes
def validate_form(form):  # TODO: Can we make this return a list of errors instead of raising exception?
    """Performs all implemented validations on a form. Raises exception if something fails.

    A non-Form argument is reported immediately through ``error(...)``.
    Problems found by the individual checks are collected in ``errors`` and
    reported together through a single ``error(...)`` call at the end.
    """
    errors = []
    # NOTE(review): entries appended to this list are never emitted or
    # returned, so these warnings are currently discarded.
    warnings = []

    if not isinstance(form, Form):
        msg = "Validation failed, not a Form:\n%s" % repr(form)
        error(msg)
        # errors.append(msg)
        # return errors

    # FIXME: Add back check for multilinearity
    # Check that form is multilinear
    # if not is_multilinear(form):
    #     errors.append("Form is not multilinear in arguments.")

    # FIXME DOMAIN: Add check for consistency between domains somehow
    domains = set(t.domain()
                  for e in iter_expressions(form)
                  for t in traverse_terminals(e)) - set((None,))
    if not domains:
        errors.append("Missing domain definition in form.")

    # None has already been removed from ``domains`` above, so no extra
    # filtering is needed here.
    top_domains = set(dom.top_domain() for dom in domains)
    if not top_domains:
        errors.append("Missing domain definition in form.")
    elif len(top_domains) > 1:
        warnings.append("Multiple top domain definitions in form: %s" % str(top_domains))

    # Check that cell is the same everywhere
    cells = set(dom.cell() for dom in top_domains) - set((None,))
    if not cells:
        errors.append("Missing cell definition in form.")
    elif len(cells) > 1:
        errors.append("Multiple cell definitions in form: %s" % str(cells))

    # Check that no Coefficient or Argument instance
    # have the same count unless they are the same
    coefficients = {}
    arguments = {}
    for e in iter_expressions(form):
        for f in traverse_terminals(e):
            if isinstance(f, Coefficient):
                c = f.count()
                if c in coefficients:
                    g = coefficients[c]
                    if f is not g:
                        errors.append("Found different Coefficients with " +
                                      "same count: %s and %s." % (repr(f), repr(g)))
                else:
                    coefficients[c] = f

            elif isinstance(f, Argument):
                c = f.count()
                if c in arguments:
                    g = arguments[c]
                    if f is not g:
                        # Counts -2 and -1 label test and trial functions in
                        # these messages.
                        if c == -2:
                            msg = "TestFunctions"
                        elif c == -1:
                            msg = "TrialFunctions"
                        else:
                            msg = "Arguments with same count"
                        msg = "Found different %s: %s and %s." % (msg, repr(f), repr(g))
                        errors.append(msg)
                else:
                    arguments[c] = f

    # Check that all integrands are scalar
    for expression in iter_expressions(form):
        if not is_true_ufl_scalar(expression):
            errors.append("Found non-scalar integrand expression:\n%s\n%s" %
                          (str(expression), repr(expression)))

    # Check that restrictions are permissible
    for integral in form.integrals():
        # Only allow restrictions on interior facet integrals and surface measures
        if integral.measure().domain_type() in (Measure.INTERIOR_FACET, Measure.SURFACE):
            check_restrictions(integral.integrand(), True)
        else:
            check_restrictions(integral.integrand(), False)

    # Raise exception with all error messages
    # TODO: Return errors list instead, need to collect messages from all validations above first.
    if errors:
        final_msg = 'Found errors in validation of form:\n%s' % '\n\n'.join(errors)
        error(final_msg)
def separate(self):
    """Separate the parametrized coefficients of ``self._form`` from the rest of the form.

    The work is done in five steps, each announced with a DEBUG log line:

    1. For every integral of ``self._form``, walk each integrand expression
       in pre-order. Collect into a new entry of ``self._coefficients`` the
       sub-expressions that meet two conditions:
       - they contain no Argument and no terminal that breaks the
         "parametrized" check below;
       - they contain at least one parametrized expression (``"mu_0"`` in
         its ``_parameters``) or a problem solution / solution_dot.
       Integrals that yield no coefficient get no entry.
    2. Integrals without coefficients are appended as
       ``integrand * measure`` to ``self._form_unchanged``. The others get
       one ``Constant(self._NaN * ones(shape))`` placeholder per coefficient,
       stored in ``self._placeholders``. Their integrand, with every
       coefficient replaced by its placeholder, is appended to
       ``self._form_with_placeholders``.
    3. Assert that no parametrized expression is left in the placeholder forms.
    4. Store ``wrapping.expression_name`` of every coefficient in
       ``self._placeholder_names``.
    5. Assert that the per-integral lists have consistent lengths.

    In strict mode (``self._strict``), the check also breaks on:
    - geometric quantities other than FacetNormal;
    - problem solutions and solution_dots.
    A candidate expression is also split into its parametrized parts.
    """

    class _SeparatedParametrizedForm_Replacer(Transformer):
        # Replaces every node found as a key of ``mapping`` by its value and
        # rebuilds all the other nodes from their (possibly replaced) operands.
        def __init__(self, mapping):
            Transformer.__init__(self)
            self.mapping = mapping

        def operator(self, e, *ops):
            if e in self.mapping:
                return self.mapping[e]
            else:
                return e._ufl_expr_reconstruct_(*ops)

        def terminal(self, e):
            return self.mapping.get(e, e)

    logger.log(DEBUG, "*** SEPARATE FORM COEFFICIENTS ***")

    logger.log(DEBUG, "1. Extract coefficients")
    integral_to_coefficients = dict()
    for integral in self._form.integrals():
        logger.log(DEBUG, "\t Currently on integrand " + str(integral.integrand()))
        self._coefficients.append(list())  # of ParametrizedExpression
        for e in iter_expressions(integral):
            logger.log(DEBUG, "\t\t Expression " + str(e))
            pre_traversal_e = [n for n in pre_traversal(e)]
            # Nodes are marked by position, not by value: see the comment below.
            tree_nodes_skip = [False for _ in pre_traversal_e]
            for (n_i, n) in enumerate(pre_traversal_e):
                if not tree_nodes_skip[n_i]:
                    # Skip expressions which are trivially non parametrized
                    if isinstance(n, Argument):
                        logger.log(DEBUG, "\t\t Node " + str(n) + " is skipped because it is an Argument")
                        continue
                    elif isinstance(n, Constant):
                        logger.log(DEBUG, "\t\t Node " + str(n) + " is skipped because it is a Constant")
                        continue
                    elif isinstance(n, MultiIndex):
                        logger.log(DEBUG, "\t\t Node " + str(n) + " is skipped because it is a MultiIndex")
                        continue
                    # Skip all expressions with at least one leaf which is an Argument
                    for t in traverse_terminals(n):
                        if isinstance(t, Argument):
                            logger.log(DEBUG, "\t\t Node " + str(n) + " is skipped because it contains an Argument")
                            break
                    else:  # not broken
                        logger.log(DEBUG, "\t\t Node " + str(n) + " and its descendants are being analyzed for non-parametrized check")
                        # Make sure to skip all descendants of this node in the outer loop
                        # Note that a map with key set to the expression is not enough to
                        # mark the node as visited, since the same expression may appear
                        # on different sides of the tree
                        pre_traversal_n = [d for d in pre_traversal(n)]
                        for (d_i, d) in enumerate(pre_traversal_n):
                            assert d == pre_traversal_e[n_i + d_i]  # make sure that we are marking the right node
                            tree_nodes_skip[n_i + d_i] = True
                        # We might be able to strip any (non-parametrized) expression out
                        all_candidates = list()
                        internal_tree_nodes_skip = [False for _ in pre_traversal_n]
                        for (d_i, d) in enumerate(pre_traversal_n):
                            if not internal_tree_nodes_skip[d_i]:
                                # Skip all expressions where at least one leaf is not parametrized
                                for t in traverse_terminals(d):
                                    if isinstance(t, BaseExpression):
                                        if wrapping.is_pull_back_expression(t) and not wrapping.is_pull_back_expression_parametrized(t):
                                            logger.log(DEBUG, "\t\t\t Descendant node " + str(d) + " causes the non-parametrized check to break because it contains a non-parametrized pulled back expression")
                                            break
                                        else:
                                            parameters = t._parameters
                                            if "mu_0" not in parameters:
                                                logger.log(DEBUG, "\t\t\t Descendant node " + str(d) + " causes the non-parametrized check to break because it contains a non-parametrized expression")
                                                break
                                    elif isinstance(t, Constant):
                                        logger.log(DEBUG, "\t\t\t Descendant node " + str(d) + " causes the non-parametrized check to break because it contains a constant")
                                        break
                                    elif isinstance(t, GeometricQuantity) and not isinstance(t, FacetNormal) and self._strict:
                                        logger.log(DEBUG, "\t\t\t Descendant node " + str(d) + " causes the non-parametrized check to break because it contains a geometric quantity and strict mode is on")
                                        break
                                    elif wrapping.is_problem_solution_type(t):
                                        if not wrapping.is_problem_solution(t) and not wrapping.is_problem_solution_dot(t):
                                            logger.log(DEBUG, "\t\t\t Descendant node " + str(d) + " causes the non-parametrized check to break because it contains a non-parametrized function")
                                            break
                                        elif self._strict:  # solutions are not allowed, break
                                            if wrapping.is_problem_solution(t):
                                                (_, component, solution) = wrapping.solution_identify_component(t)
                                                problem = get_problem_from_solution(solution)
                                                logger.log(DEBUG, "\t\t\t Descendant node " + str(d) + " causes the non-parametrized check to break because it contains the solution of " + problem.name() + " (exact problem decorator: " + str(hasattr(problem, "__is_exact__")) + ", component: " + str(component) + ") and strict mode is on")
                                                break
                                            elif wrapping.is_problem_solution_dot(t):
                                                (_, component, solution_dot) = wrapping.solution_dot_identify_component(t)
                                                problem = get_problem_from_solution_dot(solution_dot)
                                                logger.log(DEBUG, "\t\t\t Descendant node " + str(d) + " causes the non-parametrized check to break because it contains the solution_dot of " + problem.name() + " (exact problem decorator: " + str(hasattr(problem, "__is_exact__")) + ", component: " + str(component) + ") and strict mode is on")
                                                # Fix: this break was missing. Without it the for-else below
                                                # accepted the solution_dot as a candidate even in strict mode,
                                                # contradicting the message just logged and the solution branch.
                                                break
                                            else:
                                                raise RuntimeError("Unidentified solution found")
                                else:
                                    # Every terminal passed the check: d is a candidate if it contains
                                    # at least one parametrized expression or a solution / solution_dot.
                                    at_least_one_expression_or_solution = False
                                    for t in traverse_terminals(d):
                                        if isinstance(t, BaseExpression):  # which is parametrized, because previous for loop was not broken
                                            at_least_one_expression_or_solution = True
                                            logger.log(DEBUG, "\t\t\t Descendant node " + str(d) + " is a candidate after non-parametrized check because it contains the parametrized expression " + str(t))
                                            break
                                        elif wrapping.is_problem_solution_type(t):
                                            if wrapping.is_problem_solution(t):
                                                at_least_one_expression_or_solution = True
                                                (_, component, solution) = wrapping.solution_identify_component(t)
                                                problem = get_problem_from_solution(solution)
                                                logger.log(DEBUG, "\t\t\t Descendant node " + str(d) + " is a candidate after non-parametrized check because it contains the solution of " + problem.name() + " (exact problem decorator: " + str(hasattr(problem, "__is_exact__")) + ", component: " + str(component) + ")")
                                                break
                                            elif wrapping.is_problem_solution_dot(t):
                                                at_least_one_expression_or_solution = True
                                                (_, component, solution_dot) = wrapping.solution_dot_identify_component(t)
                                                problem = get_problem_from_solution_dot(solution_dot)
                                                logger.log(DEBUG, "\t\t\t Descendant node " + str(d) + " is a candidate after non-parametrized check because it contains the solution_dot of " + problem.name() + " (exact problem decorator: " + str(hasattr(problem, "__is_exact__")) + ", component: " + str(component) + ")")
                                                break
                                    if at_least_one_expression_or_solution:
                                        all_candidates.append(d)
                                        # Do not analyze the descendants of an accepted candidate again
                                        pre_traversal_d = [q for q in pre_traversal(d)]
                                        for (q_i, q) in enumerate(pre_traversal_d):
                                            assert q == pre_traversal_n[d_i + q_i]  # make sure that we are marking the right node
                                            internal_tree_nodes_skip[d_i + q_i] = True
                                    else:
                                        logger.log(DEBUG, "\t\t\t Descendant node " + str(d) + " has not passed the non-parametrized because it is not a parametrized expression or a solution")
                        # Evaluate candidates
                        if len(all_candidates) == 0:  # the whole expression was actually non-parametrized
                            logger.log(DEBUG, "\t\t Node " + str(n) + " is skipped because it is a non-parametrized coefficient")
                            continue
                        elif len(all_candidates) == 1:  # the whole expression was actually parametrized
                            logger.log(DEBUG, "\t\t Node " + str(n) + " will be accepted because it is a non-parametrized coefficient")
                            pass
                        else:  # part of the expression was not parametrized, and separating the non parametrized part may result in more than one coefficient
                            if self._strict:  # non parametrized coefficients are not allowed, so split the expression
                                logger.log(DEBUG, "\t\t\t Node " + str(n) + " will be accepted because it is a non-parametrized coefficient with more than one candidate. It will be split because strict mode is on. Its split coefficients are " + ", ".join([str(c) for c in all_candidates]))
                            else:  # non parametrized coefficients are allowed, so go on with the whole expression
                                logger.log(DEBUG, "\t\t\t Node " + str(n) + " will be accepted because it is a non-parametrized coefficient with more than one candidate. It will not be split because strict mode is off. Splitting it would have resulted in more than one coefficient, namely " + ", ".join([str(c) for c in all_candidates]))
                                all_candidates = [n]
                        # Add the coefficient(s)
                        for candidate in all_candidates:
                            def preprocess_candidate(candidate):
                                # Strip Indexed (with at least one mute index), IndexSum and
                                # ListTensor wrappers so that the stored coefficient is the
                                # underlying expression; fixed-index wrappers are kept as-is.
                                if isinstance(candidate, Indexed):
                                    assert len(candidate.ufl_operands) == 2
                                    assert isinstance(candidate.ufl_operands[1], MultiIndex)
                                    if all([isinstance(index, FixedIndex) for index in candidate.ufl_operands[1].indices()]):
                                        logger.log(DEBUG, "\t\t\t Preprocessed descendant node " + str(candidate) + " as an Indexed expression with fixed indices, resulting in a candidate " + str(candidate) + " of type " + str(type(candidate)))
                                        return candidate  # no further preprocessing needed
                                    else:
                                        logger.log(DEBUG, "\t\t\t Preprocessed descendant node " + str(candidate) + " as an Indexed expression with at least one mute index, resulting in a candidate " + str(candidate.ufl_operands[0]) + " of type " + str(type(candidate.ufl_operands[0])))
                                        return preprocess_candidate(candidate.ufl_operands[0])
                                elif isinstance(candidate, IndexSum):
                                    assert len(candidate.ufl_operands) == 2
                                    assert isinstance(candidate.ufl_operands[1], MultiIndex)
                                    assert all([isinstance(index, MuteIndex) for index in candidate.ufl_operands[1].indices()])
                                    logger.log(DEBUG, "\t\t\t Preprocessed descendant node " + str(candidate) + " as an IndexSum expression, resulting in a candidate " + str(candidate.ufl_operands[0]) + " of type " + str(type(candidate.ufl_operands[0])))
                                    return preprocess_candidate(candidate.ufl_operands[0])
                                elif isinstance(candidate, ListTensor):
                                    candidates = set([preprocess_candidate(component) for component in candidate.ufl_operands])
                                    if len(candidates) == 1:
                                        preprocessed_candidate = candidates.pop()
                                        logger.log(DEBUG, "\t\t\t Preprocessed descendant node " + str(candidate) + " as an ListTensor expression with a unique preprocessed component, resulting in a candidate " + str(preprocessed_candidate) + " of type " + str(type(preprocessed_candidate)))
                                        return preprocess_candidate(preprocessed_candidate)
                                    else:
                                        at_least_one_mute_index = False
                                        candidates_from_components = list()
                                        for component in candidates:
                                            assert isinstance(component, (ComponentTensor, Indexed))
                                            assert len(component.ufl_operands) == 2
                                            assert isinstance(component.ufl_operands[1], MultiIndex)
                                            if not all([isinstance(index, FixedIndex) for index in component.ufl_operands[1].indices()]):
                                                at_least_one_mute_index = True
                                            candidates_from_components.append(preprocess_candidate(component.ufl_operands[0]))
                                        if at_least_one_mute_index:
                                            candidates_from_components = set(candidates_from_components)
                                            assert len(candidates_from_components) == 1
                                            preprocessed_candidate = candidates_from_components.pop()
                                            logger.log(DEBUG, "\t\t\t Preprocessed descendant node " + str(candidate) + " as an ListTensor expression with multiple preprocessed components with at least one mute index, resulting in a candidate " + str(preprocessed_candidate) + " of type " + str(type(preprocessed_candidate)))
                                            return preprocess_candidate(preprocessed_candidate)
                                        else:
                                            logger.log(DEBUG, "\t\t\t Preprocessed descendant node " + str(candidate) + " as an ListTensor expression with multiple preprocessed components with fixed indices, resulting in a candidate " + str(candidate) + " of type " + str(type(candidate)))
                                            return candidate  # no further preprocessing needed
                                else:
                                    logger.log(DEBUG, "\t\t\t No preprocessing required for descendant node " + str(candidate) + " as a coefficient of type " + str(type(candidate)))
                                    return candidate

                            preprocessed_candidate = preprocess_candidate(candidate)
                            if preprocessed_candidate not in self._coefficients[-1]:
                                self._coefficients[-1].append(preprocessed_candidate)
                            logger.log(DEBUG, "\t\t\t Accepting descendant node " + str(preprocessed_candidate) + " as a coefficient of type " + str(type(preprocessed_candidate)))
                else:
                    logger.log(DEBUG, "\t\t Node " + str(n) + " to be skipped because it is a descendant of a coefficient which has already been detected")
        if len(self._coefficients[-1]) == 0:  # then there were no coefficients to extract
            logger.log(DEBUG, "\t There were no coefficients to extract")
            self._coefficients.pop()  # remove the (empty) element that was added to possibly store coefficients
        else:
            logger.log(DEBUG, "\t Extracted coefficients are:")
            for c in self._coefficients[-1]:
                logger.log(DEBUG, "\t\t" + str(c))
            integral_to_coefficients[integral] = self._coefficients[-1]

    logger.log(DEBUG, "2. Prepare placeholders and forms with placeholders")
    for integral in self._form.integrals():
        # Prepare measure for the new form (from firedrake/mg/ufl_utils.py)
        measure = Measure(integral.integral_type(), domain=integral.ufl_domain(), subdomain_id=integral.subdomain_id(), subdomain_data=integral.subdomain_data(), metadata=integral.metadata())
        if integral not in integral_to_coefficients:
            logger.log(DEBUG, "\t Adding form for integrand " + str(integral.integrand()) + " to unchanged forms")
            self._form_unchanged.append(integral.integrand() * measure)
        else:
            logger.log(DEBUG, "\t Preparing form with placeholders for integrand " + str(integral.integrand()))
            self._placeholders.append(list())  # of Constants
            placeholders_dict = dict()
            for c in integral_to_coefficients[integral]:
                self._placeholders[-1].append(Constant(self._NaN * ones(c.ufl_shape)))
                placeholders_dict[c] = self._placeholders[-1][-1]
                logger.log(DEBUG, "\t\t " + str(placeholders_dict[c]) + " is the placeholder for " + str(c))
            replacer = _SeparatedParametrizedForm_Replacer(placeholders_dict)
            new_integrand = apply_transformer(integral.integrand(), replacer)
            self._form_with_placeholders.append(new_integrand * measure)

    logger.log(DEBUG, "3. Assert that there are no parametrized expressions left")
    for form in self._form_with_placeholders:
        for integral in form.integrals():
            for e in pre_traversal(integral.integrand()):
                if isinstance(e, BaseExpression):
                    assert not (wrapping.is_pull_back_expression(e) and wrapping.is_pull_back_expression_parametrized(e)), "Form " + str(integral) + " still contains a parametrized pull back expression"
                    parameters = e._parameters
                    assert "mu_0" not in parameters, "Form " + str(integral) + " still contains a parametrized expression"

    logger.log(DEBUG, "4. Prepare coefficients hash codes")
    for addend in self._coefficients:
        self._placeholder_names.append(list())  # of string
        for factor in addend:
            self._placeholder_names[-1].append(wrapping.expression_name(factor))

    logger.log(DEBUG, "5. Assert list length consistency")
    assert len(self._coefficients) == len(self._placeholders)
    assert len(self._coefficients) == len(self._placeholder_names)
    for (c, p, pn) in zip(self._coefficients, self._placeholders, self._placeholder_names):
        assert len(c) == len(p)
        assert len(c) == len(pn)
    assert len(self._coefficients) == len(self._form_with_placeholders)

    logger.log(DEBUG, "*** DONE - SEPARATE FORM COEFFICIENTS - DONE ***")
    logger.log(DEBUG, "")
def expression_iterator(expression):
    """Lazily yield the nodes of each sub-expression of ``expression``.

    Sub-expressions come from ``iter_expressions`` and each is walked with
    ``pre_traversal``.
    """
    for subexpression in iter_expressions(expression):
        # Original author's note: pre_traversal algorithms guarantee that
        # subsolutions are processed before solutions.
        nodes = pre_traversal(subexpression)
        for node in nodes:
            yield node
def validate_form(form):  # TODO: Can we make this return a list of errors instead of raising exception?
    """Run every implemented check on ``form``; raise if any of them fails.

    A non-Form argument is reported at once through ``error(...)``. All
    other problems are gathered and reported through a single
    ``error(...)`` call at the end.
    """
    errors = []

    if not isinstance(form, Form):
        error("Validation failed, not a Form:\n%s" % ufl_err_str(form))

    # FIXME: There's a bunch of other checks we should do here.
    # FIXME: Add back check for multilinearity (is_multilinear).
    # FIXME DOMAIN: Add check for consistency between domains somehow

    domains = set(t.ufl_domain()
                  for e in iter_expressions(form)
                  for t in traverse_unique_terminals(e))
    domains.discard(None)
    if not domains:
        errors.append("Missing domain definition in form.")

    # All domains must share a single cell.
    cells = set(dom.ufl_cell() for dom in domains)
    cells.discard(None)
    if len(cells) == 0:
        errors.append("Missing cell definition in form.")
    elif len(cells) > 1:
        errors.append("Multiple cell definitions in form: %s" % str(cells))

    # Distinct Coefficient objects must not share a count, and distinct
    # Argument objects must not share a (number, part) pair.
    coefficients = {}
    arguments = {}
    for e in iter_expressions(form):
        for f in traverse_unique_terminals(e):
            if isinstance(f, Coefficient):
                registered = coefficients.setdefault(f.count(), f)
                if registered is not f:
                    errors.append("Found different Coefficients with "
                                  "same count: %s and %s." % (repr(f), repr(registered)))
            elif isinstance(f, Argument):
                key = (f.number(), f.part())
                registered = arguments.setdefault(key, f)
                if registered is not f:
                    number = key[0]
                    if number == 0:
                        what = "TestFunctions"
                    elif number == 1:
                        what = "TrialFunctions"
                    else:
                        what = "Arguments with same number and part"
                    errors.append("Found different %s: %s and %s." % (what, repr(f), repr(registered)))

    # Every integrand must be a true scalar.
    for expression in iter_expressions(form):
        if not is_true_ufl_scalar(expression):
            errors.append("Found non-scalar integrand expression: %s\n" % ufl_err_str(expression))

    # Restrictions are only permitted in interior facet integrals.
    for integral in form.integrals():
        allow_restrictions = integral.integral_type().startswith("interior_facet")
        check_restrictions(integral.integrand(), allow_restrictions)

    # TODO: Return errors list instead of raising.
    if errors:
        error('Found errors in validation of form:\n%s' % '\n\n'.join(errors))
def extract_terminals(a):
    "Return the set of Terminal nodes reachable from any expression of a."
    terminals = set()
    for expression in iter_expressions(a):
        terminals.update(node for node in post_traversal(expression)
                         if isinstance(node, Terminal))
    return terminals
def validate_form(form):  # TODO: Can we make this return a list of errors instead of raising exception?
    """Run every implemented check on ``form``; raise if any of them fails.

    A non-Form argument is reported at once through ``error(...)``. All
    other problems are gathered and reported through a single
    ``error(...)`` call at the end.
    """
    errors = []

    if not isinstance(form, Form):
        error("Validation failed, not a Form:\n%s" % ufl_err_str(form))

    # FIXME: There's a bunch of other checks we should do here.
    # FIXME: Add back check for multilinearity (is_multilinear).
    # FIXME DOMAIN: Add check for consistency between domains somehow

    domains = set(t.ufl_domain()
                  for e in iter_expressions(form)
                  for t in traverse_unique_terminals(e))
    domains.discard(None)
    if not domains:
        errors.append("Missing domain definition in form.")

    # All domains must share a single cell.
    cells = set(dom.ufl_cell() for dom in domains)
    cells.discard(None)
    if len(cells) == 0:
        errors.append("Missing cell definition in form.")
    elif len(cells) > 1:
        errors.append("Multiple cell definitions in form: %s" % str(cells))

    # Distinct Coefficient objects must not share a count, and distinct
    # Argument objects must not share a (number, part) pair.
    coefficients = {}
    arguments = {}
    for e in iter_expressions(form):
        for f in traverse_unique_terminals(e):
            if isinstance(f, Coefficient):
                registered = coefficients.setdefault(f.count(), f)
                if registered is not f:
                    errors.append("Found different Coefficients with "
                                  "same count: %s and %s." % (repr(f), repr(registered)))
            elif isinstance(f, Argument):
                key = (f.number(), f.part())
                registered = arguments.setdefault(key, f)
                if registered is not f:
                    number = key[0]
                    if number == 0:
                        what = "TestFunctions"
                    elif number == 1:
                        what = "TrialFunctions"
                    else:
                        what = "Arguments with same number and part"
                    errors.append("Found different %s: %s and %s." % (what, repr(f), repr(registered)))

    # Every integrand must be a true scalar.
    for expression in iter_expressions(form):
        if not is_true_ufl_scalar(expression):
            errors.append("Found non-scalar integrand expression: %s\n" % ufl_err_str(expression))

    # Restrictions are only permitted in interior facet integrals.
    for integral in form.integrals():
        allow_restrictions = integral.integral_type().startswith("interior_facet")
        check_restrictions(integral.integrand(), allow_restrictions)

    # TODO: Return errors list instead of raising.
    if errors:
        error('Found errors in validation of form:\n%s' % '\n\n'.join(errors))
def extract_classes(a):
    """Return the set of ``_uflclass`` values of every node in ``a``.

    ``a`` may be a Form, an Integral or an Expr.
    """
    classes = set()
    for expression in iter_expressions(a):
        for node in post_traversal(expression):
            classes.add(node._uflclass)
    return classes
def extract_classes(a):
    """Return the set of ``_uflclass`` values of every node in ``a``.

    ``a`` may be a Form, an Integral or an Expr.
    """
    classes = set()
    for expression in iter_expressions(a):
        for node in post_traversal(expression):
            classes.add(node._uflclass)
    return classes