def test_get_variables():
    """`get_variables` reports the argument of a call, not the callee."""
    from pymbolic import var
    from dagrt.utils import get_variables

    func, arg = var("f"), var("x")
    expected = frozenset({"x"})

    # The called symbol "f" is not reported, whether the argument is
    # passed positionally or by keyword.
    assert get_variables(func(arg)) == expected
    assert get_variables(func(t=arg)) == expected
def get_read_variables(self):
    """Return the variables read by this statement.

    This is the base-class read set extended with every variable that
    occurs in the positional ``parameters`` or in the values of
    ``kw_parameters`` (positional ones are merged first).
    """
    read_vars = super().get_read_variables()
    for argument in (*self.parameters, *self.kw_parameters.values()):
        read_vars |= get_variables(argument)
    return read_vars
def test_get_variables_with_function_symbols():
    """With ``include_function_symbols=True`` the callee is reported too."""
    from pymbolic import var
    from dagrt.utils import get_variables

    call = var("f")(var("x"))
    expected = frozenset({"f", "x"})

    assert get_variables(call, include_function_symbols=True) == expected
def get_read_variables(self):
    """Return the names of the variables read by this statement.

    Variables can be read by:

    1. the expressions in ``expressions``, except for names listed in
       ``solve_variables``;
    2. the values in ``other_params``;
    3. the condition (presumably reported by the base-class
       implementation -- confirm against the parent class).

    Only the expressions' variables are filtered by ``solve_variables``.
    Reads reported by the base class or coming from ``other_params`` are
    kept even when they share a name with a solve variable.
    """
    variables = super().get_read_variables()

    # Exclude solve variables from the expressions' reads only -- not from
    # what the base class reported (that would drop genuine reads).
    solve_variables = set(self.solve_variables)
    expression_variables = set().union(
        *(get_variables(expr) for expr in self.expressions))
    variables |= expression_variables - solve_variables

    variables |= set().union(
        *(get_variables(expr) for expr in self.other_params.values()))

    return variables
def match(template, expression, free_variable_names=None,
          bound_variable_names=None, pre_match=None):
    """Attempt to match the free variables found in `template` to terms in
    `expression`, modulo associativity and commutativity.

    This implements a one-way unification algorithm, matching free variables
    in `template` to subexpressions of `expression`.

    If `free_variable_names` is *None*, then all variables except those in
    `bound_variable_names` are treated as free.

    Matches that are already known to hold can be specified in `pre_match`, a
    map from variable names to subexpressions (or strings representing
    subexpressions).

    :arg template: the pattern, as an expression or a string to be parsed.
    :arg expression: the expression to match against, as an expression or a
        string to be parsed.
    :arg free_variable_names: names that the match may bind. Defaults to all
        variables (including function symbols) of `template`.
    :arg bound_variable_names: names excluded from the default set of free
        variables; only consulted when `free_variable_names` is *None*.
    :arg pre_match: optional mapping of variable names to expressions (or
        strings) that any match must agree with.

    Return a map from variable names in `free_variable_names` to expressions.
    If more than one match exists, a warning is issued and the first match
    is returned.

    :raises ValueError: if a name in `pre_match` is not a free variable, or if
        the expressions cannot be unified.
    """
    if isinstance(template, str):
        template = parse(template)

    if isinstance(expression, str):
        expression = parse(expression)

    if bound_variable_names is None:
        bound_variable_names = set()

    if free_variable_names is None:
        from dagrt.utils import get_variables
        # Build a fresh set rather than updating in place.
        free_variable_names = (
            get_variables(template, include_function_symbols=True)
            - set(bound_variable_names))

    urecs = None
    if pre_match is not None:
        eqns = []
        for name, expr in pre_match.items():
            if name not in free_variable_names:
                raise ValueError("'%s' was given in 'pre_match' but is "
                                 "not a candidate for matching" % name)
            if isinstance(expr, str):
                expr = parse(expr)
            eqns.append((Variable(name), expr))

        # Seed the unifier with the known matches.
        from pymbolic.mapper.unifier import UnificationRecord
        urecs = [UnificationRecord(eqns)]

    unifier = _ExtendedUnifier(free_variable_names)
    records = unifier(template, expression, urecs)

    if not records:
        raise ValueError("Cannot unify expressions.")

    if len(records) > 1:
        from warnings import warn
        # stacklevel=2 attributes the warning to the caller of match().
        warn('Matching\n"{expr}"\nto\n"{template}"\n'
             "is ambiguous - using first match".format(
                 expr=expression, template=template),
             stacklevel=2)

    return {key.name: val for key, val in records[0].equations}
def _add_statement(self, stmt):
    """Append *stmt* to ``self.statements`` with dependency information.

    The statement receives:

    * a fresh id from ``self.next_statement_id()``;
    * a ``condition`` built from ``self._conditional_expression_stack``;
    * a ``depends_on`` set derived from the last writer and the
      outstanding readers of every variable it touches.

    As a side effect, ``self._writer_map``, ``self._reader_map`` and
    ``self._seen_var_names`` are updated.

    :arg stmt: the statement to add; a copy with ``id``, ``condition`` and
        ``depends_on`` replaced is what gets appended.
    """
    stmt_id = self.next_statement_id()

    read_variables = set(stmt.get_read_variables())
    written_variables = set(stmt.get_written_variables())

    # Add the global execution state as an implicitly read variable.
    read_variables.add(self._EXECUTION_STATE)

    # Build the condition attribute: True when no conditional is active,
    # the single active condition, or the conjunction of all of them.
    if not self._conditional_expression_stack:
        condition = True
    elif len(self._conditional_expression_stack) == 1:
        condition = self._conditional_expression_stack[0]
    else:
        from pymbolic.primitives import LogicalAnd
        condition = LogicalAnd(tuple(self._conditional_expression_stack))

    # Variables referenced by the condition are reads of this statement too.
    from dagrt.utils import get_variables
    read_variables |= get_variables(condition)

    is_non_assignment = (not isinstance(
        stmt, (Assign, AssignImplicit, AssignFunctionCall)))

    # We regard all non-assignments as having potential external side
    # effects (i.e., writing to EXECUTION_STATE). To keep the global
    # variables in a well-defined state, ensure that all updates to global
    # variables have happened before a non-assignment.
    if is_non_assignment:
        from dagrt.utils import is_state_variable
        read_variables |= {
            var for var in self._seen_var_names
            if is_state_variable(var)
        }
        written_variables.add(self._EXECUTION_STATE)

    depends_on = set()

    # Ensure this statement happens after the last write of all the
    # variables it reads or writes.
    for var in read_variables | written_variables:
        writer = self._writer_map.get(var, None)
        if writer is not None:
            depends_on.add(writer)

    # Ensure this statement happens after the last read(s) of the variables
    # it writes to.
    for var in written_variables:
        readers = self._reader_map.get(var, set())
        depends_on |= readers
        # Keep the graph sparse by clearing the readers set.
        readers.clear()

    # This statement is now the last writer of everything it writes.
    for var in written_variables:
        self._writer_map[var] = stmt_id

    for var in read_variables:
        # reader_map should ignore reads that happen before writes, so
        # ignore if this statement also writes *var*.
        if var in written_variables:
            continue
        self._reader_map.setdefault(var, set()).add(stmt_id)

    stmt = stmt.copy(id=stmt_id, condition=condition,
                     depends_on=frozenset(depends_on))
    self.statements.append(stmt)
    self._seen_var_names |= read_variables | written_variables
def get_read_variables(self):
    """Return the variables read by this statement.

    This is the base-class read set plus every variable appearing in
    ``self.expression`` and in ``self.time``.
    """
    inherited = super().get_read_variables()
    from_expression = get_variables(self.expression)
    from_time = get_variables(self.time)
    return inherited | from_expression | from_time