Example #1
0
def _(expr, estimate):
    """
    Cost handler for a single expression node.

    If `expr` is a routine call (per `q_routine`), its cost is the sum of the
    costs of its arguments, as computed by `_estimate_cost`, plus the cost of
    the call itself. Any other expression costs 0 here.

    Parameters
    ----------
    expr : expr-like
        The expression being costed.
    estimate : bool
        If True, a routine call is charged `estimate_values['elementary']`
        ops rather than 1.

    Returns
    -------
    tuple
        `(flops, False)`. The meaning of the second element is defined by the
        caller of this handler (not visible here).
    """
    if not q_routine(expr):
        return 0, False

    # Only the flop count (first element) of each argument's cost is used.
    # Summing a generator, rather than `zip(*...)`-unpacking, also handles
    # routines called with no arguments, which previously raised ValueError.
    flops = sum(_estimate_cost(a, estimate)[0] for a in expr.args)

    # Bypass user-defined or language-specific functions (DefFunction): only
    # their arguments contribute to the count, the call itself is free.
    if not isinstance(expr, DefFunction):
        flops += estimate_values['elementary'] if estimate else 1

    return flops, False
Example #2
0
def estimate_cost(exprs, estimate=False):
    """
    Estimate the operation count of an expression.

    Parameters
    ----------
    exprs : expr-like or list of expr-like
        One or more expressions for which the operation count is calculated.
        A dict, or a list of dicts, is also accepted; the dict values are then
        the expressions that get counted. For equalities, only the right-hand
        side is counted.
    estimate : bool, optional
        Defaults to False, in which case every function call and every power
        counts as 1 op. If True, the following rules are applied in order:
            * An operation involving only `Constant`'s counts as 1 op.
            * Transcendental functions (sin, cos, exp, log) count as 100 ops;
              any other function counts as 1 op.
            * Divisions (powers with a negative exponent) count as 25 ops.
            * Powers with exponent 0 count as 0 ops.
            * Powers with integer exponent `n>0` count as n-1 ops (as if
              it were a chain of multiplications).
            * Any other power (non-integer or non-numeric exponent) counts
              as 50 ops.

    Returns
    -------
    int
        The estimated operation count. It is 0 if `exprs` is a plain
        symbol/array. It is also 0, after emitting a warning, if the count
        cannot be computed.

    Notes
    -----
    Operations other than functions and powers (e.g. additions and
    multiplications) count as `len(args) - 1` ops, minus one for every integer
    operand. This means integer index arithmetic, such as the `i+1` in
    `A[i+1]`, is not counted.
    """
    try:
        # Is it a plain symbol/array ?
        if exprs.is_AbstractFunction or exprs.is_AbstractSymbol:
            return 0
    except AttributeError:
        pass

    # Per-operation costs, only used when `estimate=True`
    transcendental_costs = {sin: 100, cos: 100, exp: 100, log: 100}
    pow_cost = 50
    div_cost = 25

    try:
        # Is it a dict ?
        exprs = exprs.values()
    except AttributeError:
        try:
            # Could still be a list of dicts
            exprs = flatten([i.values() for i in exprs])
        except (AttributeError, TypeError):
            pass

    try:
        # At this point it must be a list of SymPy objects
        # We don't use SymPy's count_ops because we do not count integer arithmetic
        # (e.g., array index functions such as i+1 in A[i+1])
        # Also, the routine below is *much* faster than count_ops
        exprs = [i.rhs if i.is_Equality else i for i in as_tuple(exprs)]
        operations = flatten(retrieve_xops(i) for i in exprs)
        flops = 0
        for op in operations:
            if op.is_Function:
                if estimate and q_routine(op):
                    terminals = retrieve_terminals(op, deep=True)
                    if all(i.function.is_const for i in terminals):
                        flops += 1
                    else:
                        flops += transcendental_costs.get(op.__class__, 1)
                else:
                    flops += 1
            elif op.is_Pow:
                if estimate:
                    terminals = retrieve_terminals(op, deep=True)
                    if all(i.function.is_const for i in terminals):
                        flops += 1
                    elif op.exp.is_Number:
                        if op.exp < 0:
                            flops += div_cost
                        elif op.exp == 0:
                            # a**0 is free
                            pass
                        elif op.exp.is_Integer:
                            # Natural pows a**b are estimated as b-1 Muls
                            flops += int(op.exp) - 1
                        else:
                            flops += pow_cost
                    else:
                        flops += pow_cost
                else:
                    flops += 1
            else:
                # Integer operands (e.g. the 1 in i+1) are not counted
                n_integers = sum(1 for i in op.args if i.is_Integer)
                flops += len(op.args) - (1 + n_integers)
        return flops
    except Exception:
        # Deliberate best-effort fallback, but unlike a bare `except:` this
        # no longer swallows KeyboardInterrupt/SystemExit.
        warning("Cannot estimate cost of `%s`" % str(exprs))
        return 0