Esempio n. 1
0
def apply_factorizations(equations, operands_to_factor, factorization_dict):
    """Generate variants of the equations with factorizations applied.

    Which matrices are factored, and how, is governed by these rules:
    - Only matrices with the ADMITS_FACTORIZATION property are factored
      (find_operands_to_factor() performs this test).
    - All occurrences of one matrix get the same factorization. If the
      matrix A shows up twice, LU is never applied to one occurrence and
      QR to the other.
    - A matrix that appears immediately inside an inverse is always
      factored: the A in Inverse(A) is factored, whereas A and B in
      Inverse(Times(A, B)) do not have to be.
    - In a summand that contains a factored matrix but no occurrence of
      that matrix within an inverse, the matrix is left as it is. This
      ensures that in A+inv(A) the first A is not factored.
    - Some factorizations rule out others; if Cholesky can be applied, LU
      will not be applied. Which factorizations are valid per operand is
      decided in DS_factorizations(); factorization_dict holds the result.

    Every combination of factorizations that obeys these rules is produced.

    Args:
        equations: The equations to rewrite; the individual equations are
            read from its ``equations`` attribute.
        operands_to_factor: Operands considered for factorization; handed
            to find_occurrences() and find_blocking_products().
        factorization_dict: Maps an operand name to the factorizations
            valid for it. Each factorization is used through its ``id``
            and ``replacement`` attributes.

    Yields:
        tuple: ``(new_equations, matched_kernels)`` where ``new_equations``
        are the simplified equations after replacement and
        ``matched_kernels`` lists the applied factorizations, one per
        distinct ``id``.
    """
    occurrences = list(find_occurrences(equations, operands_to_factor))
    blocking = list(find_blocking_products(equations, operands_to_factor))

    # Only groups (summands) with at least one inverted occurrence are kept.
    candidates = []
    for group in group_occurrences(occurrences):
        if any(occ.type != InverseType.none for occ in group):
            candidates += group

    # Names of all operands that still show up, split into those that have
    # to be factored (symbols directly inside an inverse) and the rest.
    ops = {occ.operand.name for occ in candidates}
    must_factor = {occ.operand.name for occ in candidates if occ.symbol}
    may_factor = ops - must_factor

    # Sorted so that the enumeration order does not depend on set ordering.
    for optional_ops in powerset(sorted(may_factor)):
        selected = must_factor.union(optional_ops)
        if not selected or selected in blocking:
            continue

        # Sorted for the same reason as above.
        ordered_ops = sorted(selected)
        choices_per_op = [factorization_dict[name] for name in ordered_ops]

        # One iteration per combination of one factorization per operand.
        for combination in itertools.product(*choices_per_op):
            chosen = dict(zip(ordered_ops, combination))

            # Distinct kernels (by id), in order of first appearance.
            kernels = []
            seen_ids = set()
            for kernel in combination:
                if kernel.id in seen_ids:
                    continue
                seen_ids.add(kernel.id)
                kernels.append(kernel)

            # Gather (position, replacement) pairs per equation index.
            replacements_by_eqn = {}
            for occ in candidates:
                name = occ.operand.name
                if name not in selected:
                    continue
                pair = (occ.position, chosen[name].replacement)
                replacements_by_eqn.setdefault(occ.eqn_idx, []).append(pair)

            # Rewrite a fresh copy of the equation list.
            rewritten = list(equations.equations)
            for idx, pairs in replacements_by_eqn.items():
                rewritten[idx] = matchpy.replace_many(rewritten[idx], pairs)

            result = Equations(*rewritten).simplify()
            result.set_equivalent(equations)
            result = result.to_SOP().simplify()

            yield (result, kernels)
Esempio n. 2
0
def find_CSEs(equations):
    """Find common subexpressions and yield equations with them extracted.

    Args:
        equations (Equations): Some equations.

    Yields:
        Equations: For each detected CSE, the input equations in normal
        form with that CSE replaced by a temporary and the equation
        defining the temporary inserted.
    """
    # How the temporary has to be wrapped, depending on the type of the
    # subexpression it replaces. Types without an entry use the bare temporary.
    wrappers = {
        CSEType.transpose: Transpose,
        CSEType.inverse: Inverse,
        CSEType.inverse_transpose: InverseTranspose,
    }

    detector = CSEDetector()
    for eqn_idx, equation in enumerate(equations):
        for expr, positions, level in all_subexpressions(equation.rhs):
            # For expressions of the form Inverse(expr), the inverted
            # variant is not added because that would produce "fake" CSEs:
            # expr itself is added as well. The same holds for transpose.
            # TODO do we want to invert if expr is not square?
            with_inverse = contains_inverse(expr) and not is_inverse(expr)
            with_transpose = contains_transpose(expr) and not is_transpose(expr)
            if expr.has_property(Property.SYMMETRIC):
                with_transpose = False

            if with_inverse and with_transpose:
                cse_type = CSEType.inverse_transpose
            elif with_inverse:
                cse_type = CSEType.inverse
            elif with_transpose:
                cse_type = CSEType.transpose
            else:
                cse_type = CSEType.none

            detector.add_subexpression(
                Subexpression(expr, eqn_idx, positions, level, cse_type))

    ranked_CSEs = sorted(detector.CSEs(), key=sort_keyfunc, reverse=True)

    for CSE in ranked_CSEs:
        # Group the subexpressions by the CSE the detector assigned them to.
        grouped = {}
        for subexpr in CSE:
            cse_key = detector.subexpr_to_CSE[subexpr.id]
            grouped.setdefault(cse_key, []).append(subexpr)

        pending_inserts = []
        replacements_by_eqn = {}
        for members in grouped.values():
            members.sort()
            first_eqn_idx = min(member.eqn_idx for member in members)

            # members[0] serves as the reference expression; this works
            # because indentify_subexpression_types uses subexprs[0] as
            # reference.
            reference_expr = members[0].expr
            tmp = temporaries.create_tmp(reference_expr, True)
            pending_inserts.append((first_eqn_idx, Equal(tmp, reference_expr)))

            member_types = indentify_subexpression_types(members)
            for member, member_type in zip(members, member_types):
                position_list = list(member.positions)
                eqn_replacements = replacements_by_eqn.setdefault(
                    member.eqn_idx, [])

                # All positions but the first are replaced by nothing. The
                # (1,) prefix addresses the position within the equation.
                eqn_replacements.extend(
                    ((1,) + position, []) for position in position_list[1:])

                # The first position receives the (possibly wrapped) temporary.
                wrap = wrappers.get(member_type)
                tmp_expr = tmp if wrap is None else wrap(tmp)
                eqn_replacements.append(((1,) + position_list[0], tmp_expr))

        rewritten = list(equations.equations)
        for idx, pairs in replacements_by_eqn.items():
            rewritten[idx] = matchpy.replace_many(rewritten[idx], pairs)

        # Insert the new equation for each extracted CSE right before the
        # first occurrence of that CSE. Going through the inserts in
        # descending order keeps the smaller indices valid.
        for idx, eqn in sorted(pending_inserts, reverse=True):
            rewritten.insert(idx, eqn)

        yield Equations(*rewritten).to_normalform()