def _process_list_list_equation(lhs_list: List[_NonListExpr],
                                rhs_list: List[_NonListExpr],
                                context: _UnificationContext):
    """Process one equation between two lists of expressions.

    Depending on the shape of the two lists this either:
      * delegates to _process_var_expr_equation (one side is a lone variable
        or a lone ListExpansion of a variable),
      * delegates to _process_term_term_equation (both sides are a single
        non-expansion, non-variable expression),
      * pairs up matching leading/trailing elements and queues the resulting
        smaller equations on context.expr_expr_equations, or
      * raises when the equation can't be satisfied or can't be decided.

    Raises:
        UnificationFailedException: the lists can't be equal and no
            non-syntactically-comparable expr has been expanded so far.
        UnificationAmbiguousException: the equation can't be decided here, or
            it failed after a non-syntactically-comparable expr was expanded.
    """

    def is_list_var(expr) -> bool:
        # True for ListExpansion(<variable name>).
        return isinstance(expr, ListExpansion) and isinstance(expr.expr, str)

    def is_lone_var(exprs) -> bool:
        # True for [var] and for [ListExpansion(var)].
        return len(exprs) == 1 and (isinstance(exprs[0], str) or is_list_var(exprs[0]))

    def can_pair(left, right) -> bool:
        # Two elements can be matched positionally when neither is a list
        # expansion, or when both are expansions of the very same variable.
        if isinstance(left, ListExpansion) or isinstance(right, ListExpansion):
            return is_list_var(left) and is_list_var(right) and left.expr == right.expr
        return True

    # Normalize so that a lone variable (if any) ends up on the left.
    if not is_lone_var(lhs_list):
        lhs_list, rhs_list = rhs_list, lhs_list

    if is_lone_var(lhs_list):
        [single_var] = lhs_list
        _process_var_expr_equation(single_var, rhs_list, context)
        return

    if (len(lhs_list) == 1
            and len(rhs_list) == 1
            and not isinstance(lhs_list[0], ListExpansion)
            and not isinstance(rhs_list[0], ListExpansion)):
        [lhs_term] = lhs_list
        [rhs_term] = rhs_list
        assert not isinstance(lhs_term, str)
        assert not isinstance(rhs_term, str)
        _process_term_term_equation(lhs_term, rhs_term, context)
        return

    trimmed = False

    # Peel off leading elements that can be matched one-to-one.
    while lhs_list and rhs_list and can_pair(lhs_list[0], rhs_list[0]):
        context.expr_expr_equations.append(([lhs_list[0]], [rhs_list[0]]))
        lhs_list, rhs_list = lhs_list[1:], rhs_list[1:]
        trimmed = True

    # Peel off trailing elements that can be matched one-to-one.
    while lhs_list and rhs_list and can_pair(lhs_list[-1], rhs_list[-1]):
        context.expr_expr_equations.append(([lhs_list[-1]], [rhs_list[-1]]))
        lhs_list, rhs_list = lhs_list[:-1], rhs_list[:-1]
        trimmed = True

    if not lhs_list and not rhs_list:
        # Everything has been paired up.
        return

    strategy = context.strategy

    has_list_expansion = (any(isinstance(elem, ListExpansion) for elem in lhs_list)
                          or any(isinstance(elem, ListExpansion) for elem in rhs_list))
    if not has_list_expansion:
        # No list expansions are left, yet one side still has unmatched
        # elements: the two lists have different lengths.
        expanded_expr = context.expanded_non_syntactically_comparable_expr
        lhs_text = exprs_to_string(strategy, lhs_list)
        rhs_text = exprs_to_string(strategy, rhs_list)
        if expanded_expr:
            raise UnificationAmbiguousException(
                'Deduced %s = %s, which differ in length and have no list vars\nAfter expanding a non-syntactically-comparable expr:\n%s'
                % (lhs_text, rhs_text, expr_to_string(strategy, expanded_expr)))
        raise UnificationFailedException(
            'Deduced %s = %s, which differ in length and have no list vars'
            % (lhs_text, rhs_text))

    if trimmed:
        # Queue the trimmed lists again so that they get re-processed from the
        # start: they might have become a var-expr or term-term equation.
        context.expr_expr_equations.append((lhs_list, rhs_list))
        return

    # If one side is empty, make it the left one.
    if not rhs_list:
        lhs_list, rhs_list = rhs_list, lhs_list

    if not lhs_list:
        for arg in rhs_list:
            if not is_list_var(arg):
                if context.expanded_non_syntactically_comparable_expr:
                    raise UnificationAmbiguousException()
                raise UnificationFailedException()
            # If every element takes this path, the equation has the form:
            # [] = [*l1, ... *ln]
            context.expr_expr_equations.append(([arg], []))
        return

    # E.g. in these cases:
    # ['x', 'y', *l1] = [*l2, 'z']
    # [*l1, *l2] = [*l3, *l4]
    # ['x', *l1] = [*l2, *l3]
    raise UnificationAmbiguousException(
        'Deduced %s = %s' % (exprs_to_string(strategy, lhs_list),
                             exprs_to_string(strategy, rhs_list)))
def _process_var_expr_equation(lhs: Union[str, ListExpansion],
                               rhs_list: List[_NonListExpr],
                               context: _UnificationContext):
    """Process an equation whose left side is a variable or a ListExpansion.

    Trivial equations (same variable on both sides) are dropped. If either
    side is a variable that already has a recorded equation, the known value
    is substituted and the resulting equation is queued on
    context.expr_expr_equations. Otherwise, after an occurrence check, the
    equation is recorded in context.var_expr_equations or
    context.expanded_var_expr_equations.

    Raises:
        UnificationFailedException / UnificationAmbiguousException: when a
            plain (non-expansion) lhs is equated to a list of length != 1 that
            contains no list expansion; also propagated from _occurence_check.
    """

    def enqueue(left_exprs, right_exprs):
        # Schedule a derived list-list equation for later processing.
        context.expr_expr_equations.append((left_exprs, right_exprs))

    lhs_is_var = isinstance(lhs, str)
    lhs_is_expansion = isinstance(lhs, ListExpansion)

    # Drop trivial equations: x = x and *l = *l.
    if len(rhs_list) == 1:
        [only_rhs] = rhs_list
        if lhs_is_var and isinstance(only_rhs, str) and lhs == only_rhs:
            return
        if (lhs_is_expansion
                and isinstance(only_rhs, ListExpansion)
                and isinstance(lhs.expr, str)
                and isinstance(only_rhs.expr, str)
                and lhs.expr == only_rhs.expr):
            return

    # The lhs already has a known value: substitute it and re-queue.
    if lhs_is_var:
        if lhs in context.var_expr_equations:
            enqueue([context.var_expr_equations[lhs]], rhs_list)
            return
        if lhs in context.context_var_expr_equations:
            enqueue([context.context_var_expr_equations[lhs]], rhs_list)
            return

    if lhs_is_expansion and lhs.expr in context.expanded_var_expr_equations:
        enqueue(context.expanded_var_expr_equations[lhs.expr], rhs_list)
        return

    assert not (lhs_is_expansion and lhs.expr in context.context_var_expr_equations)

    # The rhs is a single variable with a known value: substitute and re-queue.
    if len(rhs_list) == 1:
        sole_rhs = rhs_list[0]
        if isinstance(sole_rhs, str):
            if sole_rhs in context.var_expr_equations:
                enqueue([lhs], [context.var_expr_equations[sole_rhs]])
                return
            if sole_rhs in context.context_var_expr_equations:
                enqueue([lhs], [context.context_var_expr_equations[sole_rhs]])
                return
        if isinstance(sole_rhs, ListExpansion) and isinstance(sole_rhs.expr, str):
            if sole_rhs.expr in context.expanded_var_expr_equations:
                enqueue([lhs], context.expanded_var_expr_equations[sole_rhs.expr])
                return
            assert sole_rhs.expr not in context.context_var_expr_equations

    rhs_has_expansion = any(isinstance(elem, ListExpansion) for elem in rhs_list)
    if len(rhs_list) != 1 and not lhs_is_expansion and not rhs_has_expansion:
        # Different number of args and no list expansion to consider.
        strategy = context.strategy
        expanded_expr = context.expanded_non_syntactically_comparable_expr
        lhs_text = exprs_to_string(strategy, [lhs])
        rhs_text = exprs_to_string(strategy, rhs_list)
        if expanded_expr:
            raise UnificationAmbiguousException(
                'Found expr lists of different lengths with no list exprs: %s vs %s\nAfter expanding a non-syntactically-comparable expr:\n%s'
                % (lhs_text, rhs_text, expr_to_string(strategy, expanded_expr)))
        raise UnificationFailedException(
            'Found expr lists of different lengths with no list exprs: %s vs %s'
            % (lhs_text, rhs_text))

    if lhs_is_var:
        for candidate in rhs_list:
            _occurence_check(lhs, candidate, context)
        # NOTE(review): this unpacking raises ValueError when rhs_list has a
        # length other than 1 and contains a ListExpansion (the guard above
        # only rejects lists without expansions) -- confirm whether callers
        # can produce that shape.
        [bound_expr] = rhs_list
        context.var_expr_equations[lhs] = bound_expr
        return

    assert isinstance(lhs, ListExpansion)
    for candidate in rhs_list:
        _occurence_check(lhs.expr, candidate, context)
    if len(rhs_list) == 1 and isinstance(rhs_list[0], ListExpansion):
        # *l = *expr is recorded as the non-expanded equation l = expr.
        context.var_expr_equations[lhs.expr] = rhs_list[0].expr
    else:
        context.expanded_var_expr_equations[lhs.expr] = rhs_list
def _occurence_check(var1: str, expr1: _Expr, context: _UnificationContext):
    """Check that var1 does not occur in expr1, following known equations.

    The expression is walked iteratively: term arguments are visited through
    strategy.get_term_args, list expansions through their inner expression,
    and variables through the equations recorded in
    context.var_expr_equations, context.expanded_var_expr_equations and
    context.context_var_expr_equations.

    Side effect: when context.expanded_non_syntactically_comparable_expr is
    not set yet, it is set to expr1 if expr1 is a ListExpansion, or if expr1
    is a term for which strategy.equality_requires_syntactical_equality
    returns a false value.

    Args:
        var1: the variable that is about to be bound.
        expr1: the expression var1 is about to be bound to.
        context: the unification context holding the strategy and equations.

    Raises:
        UnificationAmbiguousException: var1 occurs in expr1 and
            context.expanded_non_syntactically_comparable_expr is set.
        UnificationFailedException: var1 occurs in expr1 otherwise.
    """
    strategy = context.strategy

    if isinstance(expr1, str):
        pending = [(var1, expr1, None)]
    else:
        if isinstance(expr1, ListExpansion):
            marks_non_syntactic = True
        else:
            # Only ask the strategy when the answer can still matter, exactly
            # like the short-circuit evaluation this replaces.
            marks_non_syntactic = (
                not context.expanded_non_syntactically_comparable_expr
                and not strategy.equality_requires_syntactical_equality(expr1))
        if marks_non_syntactic and not context.expanded_non_syntactically_comparable_expr:
            context.expanded_non_syntactically_comparable_expr = expr1
        pending = [(var1, expr1, context.expanded_non_syntactically_comparable_expr)]

    while pending:
        # The third element is carried along the traversal but is not
        # consulted by this function when deciding which exception to raise.
        var, expr, only_expanded_terms_with_syntactical_equality = pending.pop()

        if isinstance(expr, str):
            if var == expr:
                equations_text = {
                    eq_var: expr_to_string(strategy, eq_expr)
                    for eq_var, eq_expr in context.var_expr_equations.items()
                }
                if context.expanded_non_syntactically_comparable_expr:
                    raise UnificationAmbiguousException(
                        "Ambiguous occurrence check for var %s while checking %s in %s with equations:\n%s\nSince the following non-syntactically-comparable expr has been expanded:\n%s"
                        % (var,
                           var1,
                           expr_to_string(strategy, expr1),
                           equations_text,
                           expr_to_string(strategy,
                                          context.expanded_non_syntactically_comparable_expr)))
                raise UnificationFailedException(
                    "Failed occurrence check for var %s while checking %s in %s with equations:\n%s"
                    % (var, var1, expr_to_string(strategy, expr1), equations_text))

            if expr in context.var_expr_equations:
                pending.append((var,
                                context.var_expr_equations[expr],
                                only_expanded_terms_with_syntactical_equality))
            if expr in context.expanded_var_expr_equations:
                # Fixed: the elements must come from the expanded equations,
                # i.e. the same dict that was just tested for membership
                # (previously this indexed var_expr_equations).
                for elem in context.expanded_var_expr_equations[expr]:
                    pending.append((var,
                                    elem,
                                    only_expanded_terms_with_syntactical_equality))
            if expr in context.context_var_expr_equations:
                pending.append((var,
                                context.context_var_expr_equations[expr],
                                only_expanded_terms_with_syntactical_equality))
        elif isinstance(expr, ListExpansion):
            pending.append((var, expr.expr, False))
        else:
            is_term_with_syntactical_equality = strategy.equality_requires_syntactical_equality(expr)
            for arg in strategy.get_term_args(expr):
                pending.append((var,
                                arg,
                                only_expanded_terms_with_syntactical_equality
                                and is_term_with_syntactical_equality))
def canonicalize(var_expr_equations: Dict[str, _NonListExpr],
                 expanded_var_expr_equations: Dict[str, List[_NonListExpr]],
                 strategy: UnificationStrategyForCanonicalization[TermT]) -> List[Tuple[Union[str, ListExpansion[TermT]], List[_NonListExpr]]]:
    """Rewrite a set of solved equations into a canonical form.

    Chains of variable equalities (var1=var2=...=varN) are re-oriented so that
    a single variable is on the right-hand side, honouring
    strategy.can_var_be_on_lhs. Then every right-hand side has the variables
    defined by other equations substituted into it. The input dicts are copied
    and not modified.

    Args:
        var_expr_equations: equations of the form var = expr.
        expanded_var_expr_equations: equations of the form
            ListExpansion(var) = [expr, ...].
        strategy: provides can_var_be_on_lhs and replace_variables_in_expr
            (and is passed on to the helper functions used here).

    Returns:
        A list with one (var, expr) pair per canonical var equation followed
        by one (ListExpansion(var), exprs) pair per canonical expanded
        equation. Empty if both inputs are empty.

    Raises:
        CanonicalizationFailedException: if a variable that must stay on the
            left-hand side is not allowed there, if a dependency cycle
            contains an expression that is not just a variable, or if a
            variable-equality chain has more than one LHS-forbidden variable.
    """
    if not var_expr_equations and not expanded_var_expr_equations:
        return []

    var_expr_equations = var_expr_equations.copy()
    expanded_var_expr_equations = expanded_var_expr_equations.copy()

    # A graph that has all variables on the LHS of equations as nodes and an edge var1->var2 if we have the equation
    # var1=expr and var2 appears in expr.
    vars_dependency_graph = nx.DiGraph()
    for lhs, rhs in var_expr_equations.items():
        vars_dependency_graph.add_node(lhs)
        for var in _get_free_variables(rhs, strategy):
            vars_dependency_graph.add_edge(lhs, var)
        if isinstance(rhs, str):
            # This is a var-var equation. We also add an edge for the flipped equation.
            # That's going to cause a cycle, but we'll deal with the cycle below once we know if any other vars are
            # part of the cycle.
            vars_dependency_graph.add_edge(rhs, lhs)
    for lhs, rhs_list in expanded_var_expr_equations.items():
        vars_dependency_graph.add_node(lhs)
        for rhs_expr in rhs_list:
            for var in _get_free_variables(rhs_expr, strategy):
                vars_dependency_graph.add_edge(lhs, var)
        if (len(rhs_list) == 1
                and isinstance(rhs_list[0], ListExpansion)
                and isinstance(rhs_list[0].expr, str)):
            # This is a var-var equation. We also add an edge for the flipped equation.
            # That's going to cause a cycle, but we'll deal with the cycle below once we know if any other vars are
            # part of the cycle.
            vars_dependency_graph.add_edge(rhs_list[0].expr, lhs)

    for vars_in_connected_component in reversed(list(
            compute_condensation_in_topological_order(vars_dependency_graph))):
        vars_in_connected_component = vars_in_connected_component.copy()
        if len(vars_in_connected_component) == 1:
            [var] = vars_in_connected_component
            if var in var_expr_equations:
                # We can't flip the equation for this var since it's a "var=term" or "var=ListExpansion(...)" equation.
                assert not isinstance(var_expr_equations[var], str)
                if not strategy.can_var_be_on_lhs(var):
                    raise CanonicalizationFailedException(
                        'Deduced equation that can\'t be flipped with LHS-forbidden var: %s = %s' % (
                            var, expr_to_string(strategy, var_expr_equations[var])))
            elif var in expanded_var_expr_equations:
                # We can't flip the equation for this var since it's a "ListExpansion(var)=var2" or
                # "ListExpansion(var)=term" equation.
                assert not (len(expanded_var_expr_equations[var]) == 1
                            and isinstance(expanded_var_expr_equations[var][0], ListExpansion)
                            and isinstance(expanded_var_expr_equations[var][0].expr, str))
                if not strategy.can_var_be_on_lhs(var):
                    raise CanonicalizationFailedException(
                        'Deduced equation that can\'t be flipped with LHS-forbidden var: ListExpansion(%s) = %s' % (
                            var, exprs_to_string(strategy, expanded_var_expr_equations[var])))
            else:
                # This var is just part of a larger term in some other equation.
                assert not next(vars_dependency_graph.successors(var), None)
        else:
            assert len(vars_in_connected_component) > 1
            # We have a loop.
            # If any expression of the loop is a term with syntactic equality, unification would be impossible because
            # we can deduce var1=expr1 in which expr1 is not just var1 and var1 appears in expr1.
            # But in this case unify() would have failed.
            # If any expression of the loop is a term with non-syntactic equality, the unification is ambiguous.
            for var in vars_in_connected_component:
                if var in var_expr_equations:
                    rhs = var_expr_equations[var]
                    if not (isinstance(rhs, str)
                            or (isinstance(rhs, ListExpansion) and isinstance(rhs.expr, str))):
                        raise CanonicalizationFailedException()
                if var in expanded_var_expr_equations:
                    rhs_list = expanded_var_expr_equations[var]
                    if not (len(rhs_list) == 1
                            and (isinstance(rhs_list[0], str)
                                 or (isinstance(rhs_list[0], ListExpansion)
                                     and isinstance(rhs_list[0].expr, str)))):
                        raise CanonicalizationFailedException()

            # All equations in the loop must be of the same kind (all expanded or all non-expanded); the
            # single-element unpacking fails otherwise.
            [is_expanded_var] = {var in expanded_var_expr_equations
                                 for var in vars_in_connected_component
                                 if var in var_expr_equations or var in expanded_var_expr_equations}

            # So here we can assume that all exprs in the loop are variables, i.e. the loop is of the form
            # var1=var2=...=varN. So we have a choice of what var to put on the RHS.
            vars_in_rhs = [var
                           for var in vars_in_connected_component
                           if not strategy.can_var_be_on_lhs(var)]
            if len(vars_in_rhs) == 0:
                # Any var would do. We pick the max just to make this function deterministic.
                rhs_var = max(vars_in_connected_component)
            elif len(vars_in_rhs) == 1:
                # This is the only one we can pick.
                [rhs_var] = vars_in_rhs
            else:
                # We need at least n-1 distinct LHS vars but we don't have enough vars allowed on the LHS.
                raise CanonicalizationFailedException(
                    'Found var equality chain that can\'t be canonicalized due to multiple LHS-forbidden vars: %s'
                    % ', '.join(vars_in_rhs))

            # Now we remove all equations defining these vars and the corresponding edges in the graph.
            for var in vars_in_connected_component:
                if var in var_expr_equations:
                    del var_expr_equations[var]
                if var in expanded_var_expr_equations:
                    del expanded_var_expr_equations[var]
                for successor in list(vars_dependency_graph.successors(var)):
                    vars_dependency_graph.remove_edge(var, successor)

            # And finally we add the rearranged equations.
            for var in vars_in_connected_component:
                if var != rhs_var:
                    if is_expanded_var:
                        expanded_var_expr_equations[var] = [ListExpansion(rhs_var)]
                    else:
                        var_expr_equations[var] = rhs_var
                    vars_dependency_graph.add_edge(var, rhs_var)

    # Intended invariant for the two dicts built below: no key of either dict appears as a free variable in any
    # value of either dict.
    canonical_var_expr_equations: Dict[str, Union[_NonListExpr, List[_NonListExpr]]] = dict()
    canonical_expanded_var_expr_equations: Dict[str, Union[_NonListExpr, List[_NonListExpr]]] = dict()

    # Visit the variables in reverse topological order, so that the equations a variable depends on have already
    # been canonicalized by the time they are substituted into it.
    for var in reversed(list(nx.lexicographical_topological_sort(vars_dependency_graph))):
        expr = var_expr_equations.get(var)
        if expr is not None:
            expr = strategy.replace_variables_in_expr(expr,
                                                      canonical_var_expr_equations,
                                                      canonical_expanded_var_expr_equations)
            canonical_var_expr_equations[var] = expr
            assert var not in expanded_var_expr_equations

        expr_list = expanded_var_expr_equations.get(var)
        if expr_list is not None:
            expr_list = _replace_variables_in_exprs(expr_list,
                                                    canonical_var_expr_equations,
                                                    canonical_expanded_var_expr_equations,
                                                    strategy)
            canonical_expanded_var_expr_equations[var] = expr_list
            assert isinstance(expr_list, list)
            assert var not in var_expr_equations

    # Fixed: the expanded equations are returned in their canonical (substituted) form too; previously the
    # non-canonical expanded_var_expr_equations dict was returned here.
    return ([(var, expr)
             for var, expr in canonical_var_expr_equations.items()]
            + [(ListExpansion(var), expr_list)
               for var, expr_list in canonical_expanded_var_expr_equations.items()])