Beispiel #1
0
def rotate_to_source(linear: Dict[Tuple[str, int], ITeg],
                     target_vars: List[TegVar],
                     source_vars: List[TegVar]) -> List[ITeg]:
    """Express every source variable in terms of the rotated target variables.

    With ``c_k = linear[(source_k.name, source_k.uid)]``:
      * row 0 is ``sum_k s_k * c_k * target_k`` where ``s_0 = +1`` and
        ``s_k = -1`` for ``k > 0``;
      * row ``r > 0`` is
        ``sum_{k >= 1} (delta_rk - c_r * c_k / (1 + c_0)) * target_k + c_r * target_0``.

    See Appendix A for details.

    Raises:
        ValueError: if a row index is not below ``len(linear)``.
    """
    size = len(target_vars)
    coeffs = [linear[(var.name, var.uid)] for var in source_vars]

    def kronecker(a, b):
        return Const(1) if a == b else Const(0)

    def first_row():
        return sum((Const(1) if k == 0 else Const(-1)) * coeffs[k] *
                   target_vars[k] for k in range(size))

    def other_row(r):
        tail = sum((kronecker(r, k) - (coeffs[r] * coeffs[k]) /
                    (1 + coeffs[0])) * target_vars[k] for k in range(1, size))
        return tail + coeffs[r] * target_vars[0]

    rows = []
    for r in range(size):
        if r == 0:
            rows.append(first_row())
        elif r < len(linear):
            rows.append(other_row(r))
        else:
            raise ValueError(
                f'Requested source coordinate index: {r} is invalid.'
            )
    return rows
Beispiel #2
0
def rotate_to_target(linear: Dict[Tuple[str, int], ITeg],
                     source_vars: List[TegVar]) -> List[ITeg]:
    """Express every rotated target variable in terms of the source variables.

    With ``c_k = linear[(source_k.name, source_k.uid)]``:
      * row 0 is ``sum_k c_k * source_k``;
      * row ``r > 0`` is
        ``-c_r * source_0 + sum_{k >= 1} (delta_rk - c_r * c_k / (1 + c_0)) * source_k``.

    See Appendix A for details.

    Raises:
        ValueError: if a row index is not below ``len(linear)``.
    """
    size = len(source_vars)
    coeffs = [linear[(var.name, var.uid)] for var in source_vars]
    rows = []
    for r in range(size):
        if r == 0:
            rows.append(sum(coeffs[k] * source_vars[k] for k in range(size)))
            continue
        if r >= len(linear):
            raise ValueError(
                f'Requested target coordinate index: {r} is out of bounds.'
            )
        tail = sum(((Const(1) if r == k else Const(0)) -
                    (coeffs[r] * coeffs[k]) / (1 + coeffs[0])) * source_vars[k]
                   for k in range(1, size))
        rows.append(-coeffs[r] * source_vars[0] + tail)
    return rows
Beispiel #3
0
    def outer_fn(e, ctx):
        """Post-order visitor that keeps only the sub-terms containing the searched expressions.

        ``ctx`` entries read here (presumably filled in by the traversal from
        the children of ``e`` -- confirm against the caller):
            'has_exprs': one flag per child, truthy when that child contains
                a searched expression.
            'exprs': the children's already-rewritten expressions.
            'let_body': expression handed to ``split_exprs`` for LetIn nodes.

        Sets ``ctx['has_expr']`` and ``ctx['expr']`` and always returns
        ``(rewritten_expr, ctx)``.

        Raises:
            AssertionError: if a searched expression sits inside a SmoothFunc.
        """
        ctx['has_expr'] = any(ctx['has_exprs'])
        # Check if we need to handle other such cases.
        assert not (ctx['has_expr'] and isinstance(e, SmoothFunc)),\
               f'expr is contained in a non-linear function {type(e)}'
        if isinstance(e, Add):
            if ctx['has_expr']:
                # Keep only the summands that contain a searched expression.
                ctx['expr'] = sum([
                    child
                    for child, has_expr in zip(ctx['exprs'], ctx['has_exprs'])
                    if has_expr
                ])
                return ctx['expr'], ctx
            else:
                ctx['expr'] = e
                return e, ctx
        elif isinstance(e, Tup):
            if ctx['has_expr']:
                # Zero out tuple slots that do not contain a searched expression.
                ctx['expr'] = Tup(*[
                    ctx['exprs'][idx] if has_expr else Const(0)
                    for idx, has_expr in enumerate(ctx['has_exprs'])
                ])
                return ctx['expr'], ctx
        elif isinstance(e, IfElse):
            if ctx['has_expr']:
                # The condition is kept as-is; child slots 1 and 2 are
                # presumably (if_body, else_body). Branches without a searched
                # expression become 0.
                ctx['expr'] = IfElse(
                    e.cond,
                    ctx['exprs'][1] if ctx['has_exprs'][1] else Const(0),
                    ctx['exprs'][2] if ctx['has_exprs'][2] else Const(0))
                # Same (expr, ctx) order as every other branch of this visitor.
                return ctx['expr'], ctx
        elif isinstance(e, LetIn):
            if any(ctx['has_exprs'][1:]):
                # Let expressions contain exprs.
                new_exprs = [
                    let_var for let_var, has_expr in zip(
                        e.new_vars, ctx['has_exprs'][1:]) if has_expr
                ]
                # Recursively split the body with the new expressions.
                s_expr = split_exprs(new_exprs, ctx['let_body'])
                let_body = (s_expr if s_expr else Const(0)) +\
                           (e.expr if ctx['has_exprs'][0] else Const(0))
                # Keep only the bindings that the new body still references.
                bindings = [(v, v_expr)
                            for v, v_expr in zip(e.new_vars, e.new_exprs)
                            if v in let_body]
                if bindings:
                    vs, es = zip(*bindings)
                    ctx['expr'] = LetIn(vs, es, let_body)
                else:
                    # No need for a let expr.
                    ctx['expr'] = let_body

                return ctx['expr'], ctx

        ctx['expr'] = e
        return ctx['expr'], ctx
Beispiel #4
0
def eliminate_bimaps(expr: ITeg):
    """Remove BiMap nodes from ``expr``, one top-level BiMap at a time.

    Each pass looks at the top-level BiMap:
      * if it contains no Delta, it is replaced in the tree by a LetIn that
        binds its targets to its target expressions;
      * otherwise the BiMap is split out with ``split_instance``, its slot in
        the tree is replaced by ``Const(0)``, and a copy of the reparameterized
        BiMap is added back to the result.
    The loop ends when no BiMap remains and returns the rewritten tree.
    """
    while True:
        bimap = top_level_instance_of(expr,
                                      lambda node: isinstance(node, BiMap))
        if bimap is None:
            return expr

        delta = top_level_instance_of(bimap,
                                      lambda node: isinstance(node, Delta))
        if delta is not None:
            linear_expr = split_instance(bimap, expr)
            zeroed_tree = substitute_instance(expr, bimap, Const(0))
            reparameterized = tree_copy(reparameterize(bimap, linear_expr))
            expr = zeroed_tree + reparameterized
        else:
            as_let = LetIn(bimap.targets, bimap.target_exprs, bimap.expr)
            expr = substitute_instance(expr, bimap, as_let)
Beispiel #5
0
def check_single_linear_var(expr, not_ctx=None):
    """Check that ``expr`` is affine in exactly one variable with coefficient 1.

    Args:
        expr: expression to inspect.
        not_ctx: variables (objects with ``name`` and ``uid``) treated as the
            affine variables. ``None`` (the default) means no variables.

    Returns:
        True if, once the constant term is removed, exactly one variable
        remains and a coefficient equal to ``Const(1)`` is present.
    """
    # Avoid a shared mutable default; an empty tuple iterates like an empty set.
    variables = () if not_ctx is None else not_ctx
    affine_list = extract_coefficients_from_affine(
        expr, {(var.name, var.uid) for var in variables})
    linear_list = remove_constant_coeff(affine_list)
    return (len(linear_list) == 1  # Single variable
            and Const(1) in linear_list.values())  # with a coefficient of 1
Beispiel #6
0
def combine_poly_sets(poly_lists: List[Dict[Tuple[Tuple[Var, int], ...], ITeg]],
                      op):
    """Combine per-operand polynomial coefficient maps under ``op``.

    Each map goes from a monomial key (a tuple of ``(variable, multiplicity)``
    pairs) to its coefficient expression. Assumes the list satisfies affine
    properties.

    Args:
        poly_lists: one coefficient map per operand.
        op: ``operator.mul`` (product of the operands) or ``operator.add``
            (sum of the operands).

    Returns:
        The combined coefficient map.

    Raises:
        ValueError: if ``op`` is neither ``operator.mul`` nor ``operator.add``.
    """
    combined_set = {}
    if op == operator.mul:
        # Cartesian product: one term per choice of a monomial from each factor.
        for poly_product in product(
                *[poly_list.items() for poly_list in poly_lists]):
            combined_variable = [monomial for monomial, _ in poly_product]
            combined_expr = reduce(operator.mul,
                                   [coeff for _, coeff in poly_product])

            # Reduce combined variables to primitive variables.
            primitive_variable = combine_variables(combined_variable)

            # Different products can reduce to the same monomial (presumably
            # e.g. x*1 and 1*x once the constant monomial is folded away by
            # combine_variables). Their coefficients must be summed;
            # overwriting would silently drop terms.
            if primitive_variable in combined_set:
                combined_set[primitive_variable] = (
                    combined_set[primitive_variable] + combined_expr)
            else:
                combined_set[primitive_variable] = combined_expr

    elif op == operator.add:
        for poly_list in poly_lists:
            for variable, expr in poly_list.items():
                combined_set[variable] = combined_set.get(variable,
                                                          Const(0)) + expr

    else:
        raise ValueError('Operation not supported')

    return combined_set
Beispiel #7
0
def get_poly_term(expr_list, multiplicities):
    """Look up the coefficient of one monomial in a polynomial coefficient map.

    Args:
        expr_list: mapping from monomial keys (tuples of ``(var, multiplicity)``
            pairs ordered by ``var.uid``) to coefficients.
        multiplicities: mapping from variable to its multiplicity in the
            requested monomial. ``CONST_VAR`` entries and zero multiplicities
            are ignored.

    Returns:
        The stored coefficient, or ``Const(0)`` when the monomial is absent.
        An empty monomial is looked up as the constant term ``((CONST_VAR, 1),)``.
    """
    factors = [(var, power) for var, power in multiplicities.items()
               if var is not CONST_VAR and power != 0]
    factors.sort(key=lambda pair: pair[0].uid)
    key = tuple(factors) or ((CONST_VAR, 1), )
    return expr_list.get(key, Const(0))
Beispiel #8
0
def translate(in_vars: Tuple[TegVar], translate: Tuple[ITeg]):
    """Build a BiMap factory that shifts each variable by its own offset.

    For each pair ``(v, t)`` a new TegVar ``{v.name}_t`` is introduced with
    ``v_t = v + t`` and inverse ``v = v_t - t``. The inverse Jacobian is
    ``Const(1)``, and the target bounds are the source bounds shifted by ``t``.

    Returns:
        ``(factory, out_vars)``. ``factory`` is ``partial(BiMap, ...)`` with
        every listed keyword bound (presumably the caller supplies the wrapped
        expression -- confirm). ``out_vars`` are the new TegVars.
    """
    out_vars = [TegVar(f'{src.name}_t') for src in in_vars]
    forward = [src + offset for src, offset in zip(in_vars, translate)]
    backward = [dst - offset for dst, offset in zip(out_vars, translate)]
    unit_jacobian = Const(1)
    uppers = [src.ub() + offset for src, offset in zip(in_vars, translate)]
    lowers = [src.lb() + offset for src, offset in zip(in_vars, translate)]
    factory = partial(BiMap,
                      targets=out_vars,
                      target_exprs=forward,
                      sources=in_vars,
                      source_exprs=backward,
                      inv_jacobian=unit_jacobian,
                      target_upper_bounds=uppers,
                      target_lower_bounds=lowers)
    return factory, out_vars
Beispiel #9
0
def extract_coefficients_from_polynomial(
        expr: ITeg,
        not_ctx: Set[Tuple[str, int]]
) -> Dict[Tuple[Tuple[Var, int], ...], ITeg]:
    """Canonicalize a polynomial into a map from monomials to coefficients.

    A monomial key is a tuple of ``(variable, multiplicity)`` pairs. The
    constant monomial is ``((CONST_VAR, 1),)``.

    Args:
        expr: expression to canonicalize. Mul and Add nodes are decomposed
            recursively and merged with ``combine_poly_sets``.
        not_ctx: ``(name, uid)`` pairs of the TegVars treated as polynomial
            variables.

    Returns:
        Mapping from monomial keys to coefficients. A TegVar in ``not_ctx``
        maps to ``{((var, 1),): Const(1)}``. Any other parametric
        sub-expression is a constant coefficient; a non-parametric one
        contributes ``Const(0)``.
    """
    if isinstance(expr, (Mul, Add)):
        op = operator.mul if isinstance(expr, Mul) else operator.add
        children_coeffs = [
            extract_coefficients_from_polynomial(child, not_ctx)
            for child in expr.children
        ]
        return combine_poly_sets(children_coeffs, op=op)
    if isinstance(expr, TegVar) and (expr.name, expr.uid) in not_ctx:
        return {((expr, 1), ): Const(1)}
    if is_expr_parametric(expr, not_ctx):
        return {((CONST_VAR, 1), ): expr}
    return {((CONST_VAR, 1), ): Const(0)}
Beispiel #10
0
    def derivs_for_single_outval(
            expr: ITeg,
            single_outval: Const,
            i: Optional[int] = None,
            output_list: Optional[List[Var]] = None,
            args: Dict[str, Any] = None) -> Tuple[List[Var], ITeg]:
        """Reverse-differentiate ``expr`` for one output seed and collect the partials.

        Args:
            expr: expression to differentiate.
            single_outval: seed passed to ``reverse_deriv_transform``.
            i: not used by this function; kept for interface compatibility.
            output_list: if given, the variables whose partial derivatives are
                returned, in this order. Variables without a contribution get
                a fresh ``d{name}`` Var and ``Const(0)``.
            args: forwarded to ``reverse_deriv_transform``.

        Returns:
            ``(new_vars, value)``. ``new_vars`` is a tuple of fresh Vars.
            ``value`` is a ``Tup`` of the partials, or the single partial when
            there is only one. Without ``output_list`` the result is ordered by
            the uid of the differentiated variable.

        Raises:
            AssertionError: if there are no derivatives to return.
        """
        partial_deriv_map = defaultdict(lambda: Const(0))

        # After deriv_transform, expr will have unbound infinitesimals
        for name_uid, e in reverse_deriv_transform(expr, single_outval, set(),
                                                   {}, args):
            partial_deriv_map[name_uid] += e

        # Introduce fresh variables for each partial derivative
        uids = [var_uid for _, var_uid in partial_deriv_map.keys()]
        new_vars = [Var(var_name) for var_name, _ in partial_deriv_map.keys()]
        new_vals = [*partial_deriv_map.values()]

        if output_list is not None:
            # Return requested list of outputs.
            var_map = {
                uid: (var, val)
                for uid, var, val in zip(uids, new_vars, new_vals)
            }
            pairs = [
                var_map.get(var.uid, (Var(f'd{var.name}'), Const(0)))
                for var in output_list
            ]
        else:
            # Return a list sorted in the order the variables were defined.
            ordered = sorted(zip(uids, new_vars, new_vals), key=lambda a: a[0])
            pairs = [(var, val) for _, var, val in ordered]

        # Check before unpacking: zip(*[]) cannot be unpacked into two names,
        # so an empty result would otherwise surface as an opaque ValueError.
        assert len(pairs) > 0, 'There must be variables to compute derivatives. '
        new_vars, new_vals = zip(*pairs)
        return new_vars, (Tup(*new_vals) if len(new_vars) > 1 else new_vals[0])
Beispiel #11
0
    def outer_fn(e, ctx):
        """Post-order visitor that replaces the searched Delta and its integral.

        When ``e`` is the searched Delta (by identity) and its argument
        ``e.expr`` is one of ``ctx['upper_tegvars']``, it becomes ``Const(1)``
        and ``e.expr -> Const(0)`` is recorded in ``ctx['eliminate_tegs']``.
        If its argument is not such a variable, it becomes ``Const(0)``.
        A Teg whose ``dvar`` was recorded becomes a LetIn that binds ``dvar``
        to the recorded value and keeps the body only when the value lies
        strictly between the integral's bounds.

        Raises:
            AssertionError: if the searched Delta is not in normal form.
        """
        if isinstance(e, Delta) and ctx['search_expr'] is e:
            assert is_delta_normal(
                e
            ), f'Delta {e} is not in normal form. Call normalize_delta() first'
            if e.expr in ctx['upper_tegvars']:
                replacement = Const(1)
                pending = {**ctx['eliminate_tegs'], e.expr: Const(0)}
                return replacement, {**ctx, 'eliminate_tegs': pending}
            return Const(0), ctx

        if isinstance(e, Teg) and e.dvar in ctx['eliminate_tegs']:
            value = ctx['eliminate_tegs'][e.dvar]
            inside = (e.lower < value) & (e.upper > value)
            guarded = IfElse(inside, e.body, Const(0))
            return LetIn([e.dvar], [value], guarded), ctx

        return e, ctx
Beispiel #12
0
def extract_coefficients_from_affine(
        expr: ITeg, not_ctx: Set[Union[Var,
                                       Tuple]]) -> Dict[Tuple[str, int], ITeg]:
    """Canonicalizes an affine expression to a mapping from variables to coefficients with a constant term.

    Keys are ``(name, uid)`` pairs, and the constant term lives under
    ``('__const__', -1)``. Mul and Add nodes are decomposed recursively and
    merged with ``combine_affine_sets``. A TegVar listed in ``not_ctx`` gets
    coefficient ``Const(1)``. Any other parametric sub-expression becomes the
    constant term; a non-parametric one contributes ``Const(0)``.
    """
    for node_type, combine_op in ((Mul, operator.mul), (Add, operator.add)):
        if isinstance(expr, node_type):
            child_maps = [
                extract_coefficients_from_affine(child, not_ctx)
                for child in expr.children
            ]
            return combine_affine_sets(child_maps, op=combine_op)

    if isinstance(expr, TegVar) and (expr.name, expr.uid) in not_ctx:
        return {(expr.name, expr.uid): Const(1)}

    constant_term = expr if is_expr_parametric(expr, not_ctx) else Const(0)
    return {('__const__', -1): constant_term}
Beispiel #13
0
    def rewrite(delta, not_ctx=None):
        """Define a change of variables so that Delta(x + c) becomes Delta(y).

        Args:
            delta: Delta whose argument is affine in exactly one TegVar of
                ``not_ctx``.
            not_ctx: TegVars treated as the affine variables. ``None`` (the
                default) means an empty set.

        Returns:
            A BiMap wrapping ``Delta(x_)`` with source ``x``, target
            ``x_ = x + c``, inverse ``x = x_ - c``, unit inverse Jacobian, and
            target bounds equal to the source bounds shifted by ``c``.

        Raises:
            AssertionError: if the affine expression does not contain exactly
                one variable.

        NOTE(review): the variable's coefficient is not inspected here and is
        assumed to be 1.
        """
        # A fresh set per call instead of a shared mutable default.
        not_ctx = set() if not_ctx is None else not_ctx
        affine_list = extract_coefficients_from_affine(delta.expr,
                                                       {(var.name, var.uid)
                                                        for var in not_ctx})
        constant = affine_list.get(('__const__', -1), Const(0))
        # The constant term is stored under uid -1; every other key is a variable.
        only_var = [(name, uid) for name, uid in affine_list.keys()
                    if uid != -1]
        assert len(
            only_var
        ) == 1, f'Only one tegvar can be included in the affine expression. {only_var}'

        var_name, var_uid = only_var[0]
        source_var = TegVar(name=var_name, uid=var_uid)
        target_var = TegVar(name=f'{var_name}_')
        return BiMap(expr=Delta(target_var),
                     sources=[source_var],
                     source_exprs=[target_var - constant],
                     targets=[target_var],
                     target_exprs=[source_var + constant],
                     inv_jacobian=Const(1),
                     target_lower_bounds=[source_var.lower_bound() + constant],
                     target_upper_bounds=[source_var.upper_bound() + constant])
Beispiel #14
0
def transfer_bounds_general(expr: BiMap, source_lower: Dict[TegVar, ITeg],
                            source_upper: Dict[TegVar, ITeg]):
    """Implements a derivative-based pessimistic bounds computation for
    continuous monotonic maps.

    For each source variable, the sign of the forward derivative of ``expr``
    along it picks which end of its interval minimizes or maximizes ``expr``.
    Where ``expr`` increases, the lower source bound gives the lower result
    and the upper source bound gives the upper result; otherwise they are
    swapped.

    Args:
        expr: the expression whose range is bounded.
        source_lower: lower bound per source variable. Its keys define the
            variables that are substituted.
        source_upper: upper bound per source variable.

    Returns:
        ``(lower, upper)``: two LetIn expressions that evaluate ``expr`` with
        every source variable bound to its selected bound.

    NOTE(review): ``deriv_expr`` is used in the let *values*, outside the
    LetIn scope, so any source variables it mentions are resolved in the
    enclosing scope -- confirm this is intended.
    """
    lb_lets = {}
    ub_lets = {}

    for tegvar in source_lower:
        deriv_expr = fwd_deriv(expr, {tegvar: Const(1)})
        # Increasing along tegvar: minimum at the lower end, maximum at the
        # upper end (same sign rule as interval arithmetic on a linear term).
        lb_lets[tegvar] = IfElse(deriv_expr > 0, source_lower[tegvar],
                                 source_upper[tegvar])
        ub_lets[tegvar] = IfElse(deriv_expr > 0, source_upper[tegvar],
                                 source_lower[tegvar])

    # Pass concrete lists, as every other LetIn construction site does,
    # rather than dict views.
    lower = LetIn(list(lb_lets.keys()), list(lb_lets.values()), expr)
    upper = LetIn(list(ub_lets.keys()), list(ub_lets.values()), expr)
    return lower, upper
Beispiel #15
0
    def rewrite(delta: Delta, not_ctx: Optional[Set] = None) -> ITeg:
        """Rotates an affine discontinuity so that it's axis-aligned (e.g. ax + by + c -> z + d).

        Args:
            delta: Delta whose argument is affine in the TegVars of ``not_ctx``.
            not_ctx: TegVars treated as the affine variables. ``None`` (the
                default) means an empty set.

        Returns:
            ``LetIn([normalization_var], [normalization_expr], BiMap(...))``.
            The BiMap maps the source variables onto rotated targets named
            ``{name}_``, uses ``normalization_var`` as the inverse Jacobian,
            and wraps ``Delta(target_0 - expr_for_dvar)``.
        """
        # A fresh set per call instead of a shared mutable default.
        not_ctx = set() if not_ctx is None else not_ctx

        # Canonicalize affine expression into a map {var: coeff}
        raw_affine_set = extract_coefficients_from_affine(
            delta.expr, {(var.name, var.uid)
                         for var in not_ctx})

        # Introduce a constant term if there isn't one
        if ('__const__', -1) not in raw_affine_set:
            raw_affine_set[('__const__', -1)] = Const(0)

        # Extract source variables (in order)
        source_vars = [
            TegVar(name=name, uid=uid)
            for name, uid in var_list(remove_constant_coeff(raw_affine_set))
        ]

        # Create rotated (target) variables
        target_vars = [TegVar(name=f'{var.name}_') for var in source_vars]

        # TODO: Currently, do not handle degeneracy at -1
        # NOTE(review): flip_condition is not used below; presumably negating
        # the whole Delta argument is harmless because the Dirac delta is
        # even -- confirm.
        affine_set, flip_condition = negate_degenerate_coeffs(
            raw_affine_set, source_vars)
        linear_set = remove_constant_coeff(affine_set)
        normalized_set, normalization_var, normalization_expr = normalize_linear(
            linear_set)

        dvar = target_vars[0]
        expr_for_dvar = -constant_coeff(affine_set) * normalization_var

        source_exprs = rotate_to_source(normalized_set, target_vars,
                                        source_vars)
        target_exprs = rotate_to_target(normalized_set, source_vars)
        lower_bounds, upper_bounds = bounds_of(normalized_set, source_vars)

        return LetIn([normalization_var], [normalization_expr],
                     BiMap(expr=Delta(dvar - expr_for_dvar),
                           sources=source_vars,
                           source_exprs=source_exprs,
                           targets=target_vars,
                           target_exprs=target_exprs,
                           inv_jacobian=normalization_var,
                           target_lower_bounds=lower_bounds,
                           target_upper_bounds=upper_bounds))
Beispiel #16
0
def rotate_2d(x, y, theta):
    """Build a BiMap factory that rotates the (x, y) plane by ``theta``.

    Targets are ``x_ = x*cos(theta) + y*sin(theta)`` and
    ``y_ = -x*sin(theta) + y*cos(theta)``. The source expressions give the
    inverse rotation. The inverse Jacobian is ``Const(1)`` because a rotation
    preserves area.

    Args:
        x, y: source integration variables. Their ``lb()``/``ub()`` are used
            to build the target bounds.
        theta: rotation angle expression.

    Returns:
        ``(factory, [x_, y_])``. ``factory`` is ``partial(BiMap, ...)`` with
        every listed keyword bound (presumably the caller supplies the wrapped
        expression -- confirm).
    """
    x_ = TegVar('x_')
    y_ = TegVar('y_')
    # Bounds use interval arithmetic: each term takes the source bound that
    # minimizes (lower) or maximizes (upper) it, based on the sign of its
    # coefficient. In the y_ row the x coefficient is -sin(theta), so the
    # Sin(theta) > 0 test selects the opposite end of x's interval.
    return (partial(BiMap,
                    targets=[x_, y_],
                    target_exprs=[x * Cos(theta) + y * Sin(theta), -x * Sin(theta) + y * Cos(theta)],
                    sources=[x, y],
                    source_exprs=[x_ * Cos(theta) - y_ * Sin(theta), x_ * Sin(theta) + y_ * Cos(theta)],
                    inv_jacobian=Const(1),
                    target_lower_bounds=[Cos(theta) * IfElse(Cos(theta) > 0, x.lb(), x.ub()) +
                                         Sin(theta) * IfElse(Sin(theta) > 0, y.lb(), y.ub()),
                                         -Sin(theta) * IfElse(Sin(theta) > 0, x.ub(), x.lb()) +
                                         Cos(theta) * IfElse(Cos(theta) > 0, y.lb(), y.ub())],
                    target_upper_bounds=[Cos(theta) * IfElse(Cos(theta) > 0, x.ub(), x.lb()) +
                                         Sin(theta) * IfElse(Sin(theta) > 0, y.ub(), y.lb()),
                                         -Sin(theta) * IfElse(Sin(theta) > 0, x.lb(), x.ub()) +
                                         Cos(theta) * IfElse(Cos(theta) > 0, y.ub(), y.lb())]),
            [x_, y_])
Beispiel #17
0
def bounds_of(linear: Dict[Tuple[str, int], ITeg],
              source_vars: List[TegVar]) -> Tuple[List[ITeg], List[ITeg]]:
    """Generates the bounds of integration after rotation (i.e., it's the bounds transfer function).

    Each rotated coordinate is a linear combination of the source variables,
    so its interval comes from interval arithmetic. For the lower bound, every
    term takes the source lower bound when its coefficient is positive and the
    upper bound otherwise; the upper bound does the reverse.

    Args:
        linear: maps ``(name, uid)`` of each source variable to its coefficient.
        source_vars: source variables, in order. Their ``lower_bound()`` and
            ``upper_bound()`` are read.

    Returns:
        ``(lower_bounds, upper_bounds)``, one entry per target coordinate.

    Raises:
        ValueError: if a target index is not below ``len(linear)``.
    """
    lower_bounds, upper_bounds = [], []
    num_vars = len(source_vars)
    exprs = [linear[(s_var.name, s_var.uid)] for s_var in source_vars]

    def coeff(u, v):
        # Coefficient of source_vars[v] in target row u > 0; the same matrix
        # that rotate_to_target applies.
        if v == 0:
            return -exprs[u]
        return ((Const(1) if u == v else Const(0)) -
                (exprs[u] * exprs[v]) / (Const(1) + exprs[0]))

    for target_index in range(num_vars):
        if target_index == 0:
            lower = sum(exprs[i] *
                        IfElse(exprs[i] > 0, source_vars[i].lower_bound(),
                               source_vars[i].upper_bound())
                        for i in range(num_vars))
            upper = sum(exprs[i] *
                        IfElse(exprs[i] > 0, source_vars[i].upper_bound(),
                               source_vars[i].lower_bound())
                        for i in range(num_vars))
        elif target_index < len(linear):
            i = target_index
            lower = upper = Const(0)
            for j in range(num_vars):
                placeholder_lb = source_vars[j].lower_bound()
                placeholder_ub = source_vars[j].upper_bound()
                lower += coeff(i, j) * IfElse(
                    coeff(i, j) > 0, placeholder_lb, placeholder_ub)
                upper += coeff(i, j) * IfElse(
                    coeff(i, j) > 0, placeholder_ub, placeholder_lb)
        else:
            raise ValueError(
                f'Requested target coordinate index: {target_index} is out of bounds.'
            )
        lower_bounds.append(lower)
        upper_bounds.append(upper)

    return lower_bounds, upper_bounds
Beispiel #18
0
    def outer_fn(e, ctx):
        """Post-order visitor that normalizes Delta nodes.

        * BiMap: removes its targets from ``ctx['lower_tegvars']``.
        * TegVar: adds itself to ``ctx['lower_tegvars']``.
        * Delta:
          - It must not depend on any variable in
            ``upper_depvars - upper_tegvars``.
          - It becomes ``Const(0)`` when none of its lower TegVars is an
            enclosing integration variable.
          - Otherwise, if not already normal, it is rewritten by the first
            handler in ``HANDLERS`` that accepts it and normalized again.

        Raises:
            AssertionError: on a non-affine dependence, or when no handler
                can rewrite a non-normal Delta.
        """
        if not isinstance(e, Delta):
            if isinstance(e, BiMap):
                return e, {
                    **ctx, 'lower_tegvars': ctx['lower_tegvars'] - set(e.targets)
                }
            if isinstance(e, TegVar):
                return e, {**ctx, 'lower_tegvars': ctx['lower_tegvars'] | {e}}
            return e, ctx

        upper_tegvars = ctx['upper_tegvars']
        lower_tegvars = ctx['lower_tegvars']
        depvars = [
            tegvar for tegvar in (ctx['upper_depvars'] - upper_tegvars)
            if tegvar in e
        ]
        assert not depvars,\
               f'Delta expression {e} is not explicitly affine: ({depvars}) '\
               f'is/are dependent on one or more of {ctx["upper_tegvars"]} '\
               f'through one-way let expressions. Use bijective maps (BiMap) instead'

        touches_integral = any(k in upper_tegvars for k in lower_tegvars)
        if not lower_tegvars or not touches_integral:
            return Const(0), ctx

        if not is_delta_normal(e):
            # Every handler is queried, each with its own copy of the set.
            can_rewrites = [
                handler.can_rewrite(e, set(upper_tegvars))
                for handler in HANDLERS
            ]
            assert any(
                can_rewrites
            ), f'Cannot find any handler for delta expression {e}'
            chosen = HANDLERS[can_rewrites.index(True)]
            e = normalize_deltas(chosen.rewrite(e, set(upper_tegvars)))

        return e, ctx
Beispiel #19
0
def simplify(expr: ITeg) -> ITeg:
    """Return an algebraically simplified version of ``expr``.

    Rewrites are applied bottom-up and include:
      * dropping additive zeros and multiplicative ones;
      * annihilating products with a zero factor;
      * folding constant sub-terms via ``evaluate``;
      * gathering the constants of binary Add/Mul chains into a single
        trailing constant;
      * merging sums of LetIn / Teg / IfElse nodes that share bindings,
        bounds, or conditions;
      * factoring a shared factor out of a sum of two products;
      * inlining constant let-bindings;
      * resolving IfElse nodes whose condition is the constant ``true`` or
        ``false``;
      * replacing FwdDeriv/RevDeriv nodes by their ``deriv_expr``.

    Raises:
        ValueError: for expression types without a simplify rule.
    """

    if isinstance(expr, Var):
        return expr

    elif isinstance(expr, Add):
        expr1, expr2 = expr.children
        simple1, simple2 = simplify(expr1), simplify(expr2)
        if isinstance(simple1, Const) and simple1.value == 0:
            return simple2
        if isinstance(simple2, Const) and simple2.value == 0:
            return simple1
        if isinstance(simple1, Const) and isinstance(simple2, Const):
            return Const(evaluate(simple1 + simple2))

        # Associative reordering.
        if isinstance(simple1,
                      (Add, Const)) and isinstance(simple2, (Add, Const)):
            nodes1 = [simple1] if isinstance(simple1,
                                             Const) else simple1.children
            nodes2 = [simple2] if isinstance(simple2,
                                             Const) else simple2.children
            # Unpack so this works whether `children` is a list or a tuple.
            all_nodes = [*nodes1, *nodes2]
            assert 2 <= len(
                all_nodes
            ) <= 4, 'Unexpected number of nodes in Add-associative tree'

            const_nodes = [
                node for node in all_nodes if isinstance(node, Const)
            ]
            other_nodes = [
                node for node in all_nodes if not isinstance(node, Const)
            ]

            # No const nodes -> Reordering is pointless.
            if len(other_nodes) == len(all_nodes):
                return simple1 + simple2

            # Compress const nodes.
            const_node = Const(evaluate(reduce(operator.add, const_nodes)))

            # Move the folded constant to the end (dropping it if it is 0).
            if const_node == Const(0):
                simplified_nodes = other_nodes
            else:
                simplified_nodes = other_nodes + [const_node]

            # Everything folded away: the sum is the constant itself.
            if not simplified_nodes:
                return const_node

            # Left fold, so a kept constant is the top-level right operand.
            return reduce(operator.add, simplified_nodes)

        if isinstance(simple1, LetIn) and isinstance(simple2, LetIn):
            if simple1.new_vars == simple2.new_vars and simple1.new_exprs == simple2.new_exprs:
                return LetIn(new_vars=simple1.new_vars,
                             new_exprs=simple1.new_exprs,
                             expr=simplify(simple1.expr + simple2.expr))
            else:
                return simple1 + simple2

        if isinstance(simple1, Teg) and isinstance(simple2, Teg):
            if (simple1.dvar == simple2.dvar and simple1.lower == simple2.lower
                    and simple1.upper == simple2.upper):
                return simplify(
                    Teg(simple1.lower, simple1.upper,
                        simplify(simple1.body + simple2.body), simple1.dvar))

            else:
                return simple1 + simple2

        if isinstance(simple1, IfElse) and isinstance(simple2, IfElse):
            if simple1.cond == simple2.cond:
                return IfElse(simple1.cond,
                              simplify(simple1.if_body + simple2.if_body),
                              simplify(simple1.else_body + simple2.else_body))
            else:
                return simple1 + simple2

        if isinstance(simple1, Mul) and isinstance(simple2, Mul):
            # Distribution: factor out a factor shared by both products.
            exprLL, exprLR = simple1.children
            exprRL, exprRR = simple2.children

            if exprLL == exprRR:
                return simplify(exprLL * (simplify(exprLR + exprRL)))
            if exprLL == exprRL:
                return simplify(exprLL * (simplify(exprLR + exprRR)))
            if exprLR == exprRL:
                return simplify(exprLR * (simplify(exprLL + exprRR)))
            if exprLR == exprRR:
                return simplify(exprLR * (simplify(exprLL + exprRL)))

        return simple1 + simple2

    elif isinstance(expr, Mul):
        expr1, expr2 = expr.children
        simple1, simple2 = simplify(expr1), simplify(expr2)

        # 0-elimination
        if ((isinstance(simple1, Const) and simple1.value == 0)
                or (isinstance(simple2, Const) and hasattr(simple2, 'value')
                    and simple2.value == 0)):
            return Const(0)

        # Multiplicative identity.
        if isinstance(simple1, Const) and simple1.value == 1.0:
            return simple2
        if isinstance(simple2, Const) and simple2.value == 1.0:
            return simple1

        # Local constant compression.
        if isinstance(simple1, Const) and isinstance(simple2, Const):
            return Const(evaluate(simple1 * simple2))

        # Associative reordering.
        if isinstance(simple1,
                      (Mul, Const)) and isinstance(simple2, (Mul, Const)):
            nodes1 = [simple1] if isinstance(simple1,
                                             Const) else simple1.children
            nodes2 = [simple2] if isinstance(simple2,
                                             Const) else simple2.children
            # Unpack so this works whether `children` is a list or a tuple.
            all_nodes = [*nodes1, *nodes2]
            assert 2 <= len(
                all_nodes
            ) <= 4, 'Unexpected number of nodes in Mul-associative tree'

            const_nodes = [
                node for node in all_nodes if isinstance(node, Const)
            ]
            other_nodes = [
                node for node in all_nodes if not isinstance(node, Const)
            ]

            # No const nodes -> Reordering is pointless.
            if len(other_nodes) == len(all_nodes):
                return simple1 * simple2

            # Compress const nodes.
            const_node = Const(evaluate(reduce(operator.mul, const_nodes)))

            # Move the folded constant to the end (dropping it if it is 1).
            if not (const_node == Const(1)):
                simplified_nodes = other_nodes + [const_node]
            else:
                simplified_nodes = other_nodes

            # Everything folded away: the product is the constant itself.
            if not simplified_nodes:
                return const_node

            # Left fold, so a kept constant is the top-level right operand.
            return reduce(operator.mul, simplified_nodes)

        return simple1 * simple2

    elif isinstance(expr, Invert):
        simple = simplify(expr.child)
        if isinstance(simple, Const):
            return Const(evaluate(Invert(simple)))
        return Invert(simple)

    elif isinstance(expr, SmoothFunc):
        simple = simplify(expr.expr)
        if isinstance(simple, Const):
            return Const(evaluate(type(expr)(simple)))
        # Reuse the already-simplified argument: simplifying it a second time
        # doubles the work at every nesting level (exponential in depth).
        return type(expr)(simple)

    elif isinstance(expr, IfElse):
        cond, if_body, else_body = simplify(expr.cond), simplify(
            expr.if_body), simplify(expr.else_body)
        if (isinstance(if_body, Const) and isinstance(else_body, Const)
                and if_body.value == 0 and else_body.value == 0):
            return if_body

        if cond == true:
            return if_body

        if cond == false:
            return else_body

        return IfElse(cond, if_body, else_body)

    elif isinstance(expr, Teg):
        body = simplify(expr.body)
        if isinstance(body, Const) and hasattr(body,
                                               'value') and body.value == 0:
            return Const(0)
        return Teg(simplify(expr.lower), simplify(expr.upper), body, expr.dvar)

    elif isinstance(expr, Tup):
        return Tup(*(simplify(child) for child in expr))

    elif isinstance(expr, LetIn):
        simplified_exprs = Tup(*(simplify(e) for e in expr.new_exprs))
        child_expr = simplify(expr.expr)
        vars_list = expr.new_vars

        # Inline bindings whose value simplified to a constant.
        for s_var, s_expr in zip(vars_list, simplified_exprs):
            if isinstance(s_expr, Const):
                child_expr = substitute(child_expr, s_var, s_expr)

        non_const_bindings = [
            (s_var, s_expr)
            for s_var, s_expr in zip(vars_list, simplified_exprs)
            if not isinstance(s_expr, Const)
        ]

        child_expr = simplify(child_expr)
        if non_const_bindings:
            non_const_vars, non_const_exprs = zip(*list(non_const_bindings))
            return (LetIn(non_const_vars, non_const_exprs, child_expr)
                    if not isinstance(child_expr, Const) else child_expr)
        else:
            return child_expr

    elif isinstance(expr, BiMap):
        simplified_target_exprs = list(simplify(e) for e in expr.target_exprs)
        simplified_source_exprs = list(simplify(e) for e in expr.source_exprs)

        simplified_ubs = list(simplify(e) for e in expr.target_upper_bounds)
        simplified_lbs = list(simplify(e) for e in expr.target_lower_bounds)

        child_expr = simplify(expr.expr)

        return BiMap(expr=child_expr,
                     targets=expr.targets,
                     target_exprs=simplified_target_exprs,
                     sources=expr.sources,
                     source_exprs=simplified_source_exprs,
                     inv_jacobian=simplify(expr.inv_jacobian),
                     target_lower_bounds=simplified_lbs,
                     target_upper_bounds=simplified_ubs)

    elif isinstance(expr, Delta):
        return Delta(simplify(expr.expr))

    elif {'FwdDeriv', 'RevDeriv'} & {t.__name__ for t in type(expr).__mro__}:
        # Matched by class name rather than isinstance.
        return simplify(expr.__getattribute__('deriv_expr'))

    elif isinstance(expr, Bool):
        left_expr, right_expr = simplify(expr.left_expr), simplify(
            expr.right_expr)
        if isinstance(left_expr, Const) and isinstance(right_expr, Const):
            return false if evaluate(Bool(left_expr,
                                          right_expr)) == 0.0 else true
        return Bool(left_expr, right_expr)

    elif isinstance(expr, And):
        left_expr, right_expr = simplify(expr.left_expr), simplify(
            expr.right_expr)
        if left_expr == true:
            return right_expr
        if right_expr == true:
            return left_expr
        if left_expr == false or right_expr == false:
            return false
        return And(left_expr, right_expr)

    elif isinstance(expr, Or):
        left_expr, right_expr = simplify(expr.left_expr), simplify(
            expr.right_expr)
        if left_expr == false:
            return right_expr
        if right_expr == false:
            return left_expr
        if left_expr == true or right_expr == true:
            return true
        return Or(left_expr, right_expr)

    else:
        raise ValueError(
            f'The type of the expr "{type(expr)}" does not have a supported simplify rule'
        )
Beispiel #20
0
    t3, t4 = Var('t3'), Var('t4')

    scale_map, (x_s, y_s) = scale([x, y], [t1, t2])
    translate_map, (x_st, y_st) = translate([x_s, y_s], [t3, t4])

    # Area of a unit circle.
    bindings = {t1: 1, t2: 1, t3: 0, t4: 0, t: 0.25}
    # Derivative of threshold only.
    integral = Teg(
                x_lb, x_ub,
                Teg(y_lb, y_ub,
                    scale_map(translate_map(IfElse(x_st * y_st > t, 1, 0))), y
                    ), x
                )

    d_vars, dt_exprs = reverse_deriv(integral, Tup(Const(1)), output_list=[t, t1, t2, t3, t4])

    integral = reduce_to_base(integral)
    image = render_image(integral,
                         variables=((x_lb, x_ub), (y_lb, y_ub)),
                         bindings=bindings,
                         bounds=((-1, 1), (-1, 1)),
                         res=(args.res_x, args.res_y),
                         )
    save_image(np.abs(image), filename=f'{args.testname}.png')

    for d_var, dt_expr in zip(d_vars, dt_exprs):
        image = render_image(reduce_to_base(dt_expr),
                             variables=((x_lb, x_ub), (y_lb, y_ub)),
                             bindings=bindings,
                             bounds=((-1, 1), (-1, 1)),
Beispiel #21
0
def eliminate_deltas(expr: ITeg):
    """Repeatedly integrate out the Dirac deltas found in ``expr``.

    For each delta located by ``top_level_instance_of``:

    * in the sub-expression produced by ``split_instance``, the delta becomes
      ``Const(1)`` and the integral over the delta's variable is replaced by
      ``LetIn([dvar], [Const(0)], IfElse(lower < 0 < upper, body, 0))``; a delta
      whose variable is not bound by an enclosing integral becomes ``Const(0)``;
    * everywhere else the delta is substituted with ``Const(0)``.

    The two parts are summed and the process recurses until no delta remains.
    The delta must already be in normal form (see ``normalize_delta``).
    """

    def inner_fn(e, ctx):
        # Top-down pass: an integral brings its variable into scope for its children.
        is_teg = isinstance(e, Teg)
        child_ctx = {
            'is_expr': is_teg and ctx['search_expr'] is e,
            'upper_tegvars': ((ctx['upper_tegvars'] | {e.dvar})
                              if is_teg else ctx['upper_tegvars']),
            'search_expr': ctx['search_expr'],
        }
        return e, child_ctx

    def outer_fn(e, ctx):
        # Bottom-up pass: rewrite the targeted delta, then the integral binding it.
        if isinstance(e, Delta) and ctx['search_expr'] is e:
            assert is_delta_normal(
                e
            ), f'Delta {e} is not in normal form. Call normalize_delta() first'
            if e.expr in ctx['upper_tegvars']:
                one = Const(1)
                pending = {**ctx['eliminate_tegs'], e.expr: Const(0)}
                return one, {**ctx, 'eliminate_tegs': pending}
            # No enclosing integral binds this variable: the delta contributes nothing.
            return Const(0), ctx

        if isinstance(e, Teg) and e.dvar in ctx['eliminate_tegs']:
            value = ctx['eliminate_tegs'][e.dvar]
            inside = (e.lower < value) & (e.upper > value)
            guarded_body = IfElse(inside, e.body, Const(0))
            return LetIn([e.dvar], [value], guarded_body), ctx

        return e, ctx

    def context_combine(contexts, ctx):
        # Union the children's lower-level data; scope info comes from the parent.
        lower_tegvars = set()
        for child_ctx in contexts:
            lower_tegvars = lower_tegvars | child_ctx['lower_tegvars']

        eliminate_tegs = {}
        for child_ctx in contexts:
            eliminate_tegs = {**eliminate_tegs, **child_ctx['eliminate_tegs']}

        return {
            'lower_tegvars': lower_tegvars,
            'upper_tegvars': ctx['upper_tegvars'],
            'eliminate_tegs': eliminate_tegs,
            'search_expr': ctx['search_expr'],
        }

    def eliminate_delta(delta, t_expr):
        start_ctx = {'upper_tegvars': set(), 'search_expr': delta}
        return base_pass(t_expr, start_ctx, inner_fn, outer_fn,
                         context_combine)[0]

    delta = top_level_instance_of(expr, lambda a: isinstance(a, Delta))
    if delta is None:
        return expr

    linear_part = split_instance(delta, expr)
    remainder = substitute_instance(expr, delta, Const(0))
    integrated = tree_copy(eliminate_delta(delta, linear_part))
    return eliminate_deltas(remainder + integrated)
Beispiel #22
0
def reverse_deriv_transform(
        expr: ITeg, out_deriv_vals: Tuple, not_ctx: Set[Tuple[str, int]],
        deps: Dict[TegVar, Set[Var]],
        args: Dict[str, Any]) -> Iterable[Tuple[Tuple[str, int], ITeg]]:
    """Lazily yield the reverse-mode derivative contributions of ``expr``.

    Each yielded item is ``((f'd{name}', uid), contribution)``. It is one additive term of
    the derivative with respect to the variable ``(name, uid)``, already multiplied by the
    incoming adjoint. A variable may receive several contributions, and the caller
    (``reverse_deriv``) sums them per key.

    Args:
        expr: The expression to differentiate.
        out_deriv_vals: The adjoint flowing into ``expr`` from its consumer. Recursive calls
            pass ITeg expressions here (e.g. ``Const(1)``), despite the ``Tuple``
            annotation.
        not_ctx: ``(name, uid)`` keys of variables that receive no derivative, such as
            integration variables in scope. NOTE: the ``Teg`` branch mutates this set in
            place (``discard`` then ``add``), and the added key stays after the call.
        deps: Maps a ``TegVar`` to the variables it depends on. A TegVar in ``not_ctx``
            still gets a derivative when ``extend_dependencies`` reports a dependency
            outside ``not_ctx``.
        args: Options. ``'ignore_deltas'`` skips the Dirac-delta terms produced by
            discontinuous ``IfElse`` conditions. ``'ignore_bounds'`` skips the Leibniz-rule
            terms produced by integral bounds.

    Raises:
        ValueError: if ``expr`` is of an unsupported type.

    NOTE(review): sub-results are often lazy generators that share ``not_ctx``. The
    in-place mutation in the ``Teg`` branch therefore depends on consumption order.
    """

    if isinstance(expr, TegVar):
        # Differentiate a TegVar when it is unmasked, or when it depends on an unmasked variable.
        if (((expr.name, expr.uid) not in not_ctx)
                or {(v.name, v.uid)
                    for v in extend_dependencies({expr}, deps)} - not_ctx):
            yield ((f'd{expr.name}', expr.uid), out_deriv_vals)

    elif isinstance(expr, (Const, Delta)):
        # Constants and deltas contribute no derivative.
        pass

    elif isinstance(expr, Var):
        if (expr.name, expr.uid) not in not_ctx:
            yield ((f'd{expr.name}', expr.uid), out_deriv_vals)

    elif isinstance(expr, Add):
        # Sum rule: each child is differentiated with a unit adjoint. merge() (defined
        # elsewhere) presumably combines both lists and applies out_deriv_vals.
        left, right = expr.children
        left_list = list(
            reverse_deriv_transform(left, Const(1), not_ctx, deps, args))
        right_list = list(
            reverse_deriv_transform(right, Const(1), not_ctx, deps, args))
        yield from merge(left_list, right_list, out_deriv_vals)

    elif isinstance(expr, Mul):
        # Product rule: the adjoint of each factor is the other factor; merge() presumably
        # applies out_deriv_vals.
        left, right = expr.children
        left_list = list(
            reverse_deriv_transform(left, right, not_ctx, deps, args))
        right_list = list(
            reverse_deriv_transform(right, left, not_ctx, deps, args))
        yield from merge(left_list, right_list, out_deriv_vals)

    elif isinstance(expr, Invert):
        # d(1/u) = -(1/u)^2 du, and expr * expr is (1/u)^2.
        child = expr.child
        yield from reverse_deriv_transform(child,
                                           -out_deriv_vals * expr * expr,
                                           not_ctx, deps, args)

    elif isinstance(expr, SmoothFunc):
        # Chain rule through the function's own reverse derivative.
        child = expr.expr
        yield from reverse_deriv_transform(
            child, expr.rev_deriv(out_deriv_expr=out_deriv_vals), not_ctx,
            deps, args)

    elif isinstance(expr, IfElse):
        # Derivative of each branch, masked by the condition.
        derivs_if = reverse_deriv_transform(expr.if_body, Const(1), not_ctx,
                                            deps, args)
        derivs_else = reverse_deriv_transform(expr.else_body, Const(1),
                                              not_ctx, deps, args)
        yield from ((name_uid,
                     out_deriv_vals * IfElse(expr.cond, deriv_if, Const(0)))
                    for name_uid, deriv_if in derivs_if)
        yield from ((name_uid,
                     out_deriv_vals * IfElse(expr.cond, Const(0), deriv_else))
                    for name_uid, deriv_else in derivs_else)

        # Discontinuity terms. For each primitive comparison in the condition, emit
        # jump * d(right - left) * Delta(right - left).
        if not args.get('ignore_deltas', False):
            for boolean in primitive_booleans_in(expr.cond, not_ctx, deps):
                jump = substitute(expr, boolean, true) - substitute(
                    expr, boolean, false)
                delta_expr = boolean.right_expr - boolean.left_expr
                derivs_delta_expr = reverse_deriv_transform(
                    delta_expr, Const(1), not_ctx, deps, args)
                yield from (
                    (name_uid, out_deriv_vals * deriv_delta_expr * jump *
                     Delta(delta_expr))
                    for name_uid, deriv_delta_expr in derivs_delta_expr)

    elif isinstance(expr, Teg):
        # Unmask the integration variable while differentiating the bounds.
        not_ctx.discard((expr.dvar.name, expr.dvar.uid))

        # Apply Leibniz rule directly for moving boundaries:
        # f(upper) * d(upper) - f(lower) * d(lower).
        if not args.get('ignore_bounds', False):
            lower_derivs = reverse_deriv_transform(expr.lower, out_deriv_vals,
                                                   not_ctx, deps, args)
            upper_derivs = reverse_deriv_transform(expr.upper, out_deriv_vals,
                                                   not_ctx, deps, args)
            yield from ((name_uid, upper_deriv *
                         substitute(expr.body, expr.dvar, expr.upper))
                        for name_uid, upper_deriv in upper_derivs)
            yield from ((name_uid, -lower_deriv *
                         substitute(expr.body, expr.dvar, expr.lower))
                        for name_uid, lower_deriv in lower_derivs)

        # Mask the integration variable inside the body (never removed afterwards).
        not_ctx.add((expr.dvar.name, expr.dvar.uid))

        # Differentiate under the integral sign.
        deriv_body_traces = reverse_deriv_transform(expr.body, Const(1),
                                                    not_ctx, deps, args)

        yield from ((name_uid, out_deriv_vals *
                     Teg(expr.lower, expr.upper, deriv_body, expr.dvar))
                    for name_uid, deriv_body in deriv_body_traces)

    elif isinstance(expr, Tup):
        # NOTE(review): this yields one list of generators, not (key, expr) pairs as every
        # other branch does. Consumers that unpack pairs will not handle it; confirm intent.
        yield [
            reverse_deriv_transform(child, out_deriv_vals, not_ctx, deps, args)
            for child in expr
        ]

    elif isinstance(expr, LetIn):
        # Include derivatives of each expression to the let body.
        # body_derivs maps d<var> to the contributions of the var's defining expression, so
        # chain-rule terms that reach d<var> can be expanded below.
        dnew_vars, body_derivs = set(), {}
        for var, e in zip(expr.new_vars, expr.new_exprs):
            # A bound var defined in terms of a masked variable is itself masked.
            if any(
                    Var(name=ctx_name, uid=ctx_uid) in e
                    for ctx_name, ctx_uid in not_ctx):
                # Add dependent variables.
                # NOTE(review): the message tests `ctx_var in e` with a (name, uid) tuple,
                # which may not be what ITeg.__contains__ expects.
                assert isinstance(var, TegVar), f'{var} is dependent on TegVar(s):'\
                                                f'({[ctx_var for ctx_var in not_ctx if ctx_var in e]}).'\
                                                f'{var} must also be declared as a TegVar and not a Var'
                not_ctx = not_ctx | {(var.name, var.uid)}

            # Bound vars that the body never uses contribute nothing.
            if var not in expr.expr:
                continue
            dname = f'd{var.name}'
            dnew_vars.add((dname, var.uid))
            body_derivs[(dname, var.uid)] = list(
                reverse_deriv_transform(e, Const(1), not_ctx, deps, args))

        # Thread through derivatives of each subexpression.
        # Contributions to a bound var are re-routed through its defining expression.
        for (name, uid), dname_expr in reverse_deriv_transform(
                expr.expr, out_deriv_vals, not_ctx, deps, args):
            dvar_with_ctx = LetIn(expr.new_vars, expr.new_exprs, dname_expr)
            if (name, uid) in dnew_vars:
                yield from ((n, d * dvar_with_ctx)
                            for n, d in body_derivs[(name, uid)])
            else:
                yield ((name, uid), dvar_with_ctx)

    elif isinstance(expr, BiMap):
        # Include derivatives of each expression to the let body.
        # Same scheme as LetIn, using targets/target_exprs. Each used target is also
        # recorded in deps as depending on the variables of its expression.
        dnew_vars, body_derivs = set(), {}
        new_deps = {}
        for var, e in zip(expr.targets, expr.target_exprs):
            if any(
                    Var(name=ctx_name, uid=ctx_uid) in e
                    for ctx_name, ctx_uid in not_ctx):
                # Add dependent variables.
                assert isinstance(var, TegVar), f'{var} is dependent on TegVar(s):'\
                                                f'({[ctx_var for ctx_var in not_ctx if ctx_var in e]}).'\
                                                f'{var} must also be declared as a TegVar and not a Var'
                not_ctx = not_ctx | {(var.name, var.uid)}
            if var in expr.expr:
                new_deps[var] = extract_vars(e)
                dname = f'd{var.name}'
                dnew_vars.add((dname, var.uid))
                body_derivs[(dname, var.uid)] = list(
                    reverse_deriv_transform(e, Const(1), not_ctx, deps, args))

        deps = {**deps, **new_deps}
        # Thread through derivatives of each subexpression, keeping the BiMap wrapper.
        for (name, uid), dname_expr in reverse_deriv_transform(
                expr.expr, out_deriv_vals, not_ctx, deps, args):
            dvar_with_ctx = BiMap(dname_expr,
                                  expr.targets,
                                  expr.target_exprs,
                                  expr.sources,
                                  expr.source_exprs,
                                  inv_jacobian=expr.inv_jacobian,
                                  target_lower_bounds=expr.target_lower_bounds,
                                  target_upper_bounds=expr.target_upper_bounds)
            if (name, uid) in dnew_vars:
                yield from ((n, d * dvar_with_ctx)
                            for n, d in body_derivs[(name, uid)])
            else:
                yield ((name, uid), dvar_with_ctx)

    else:
        raise ValueError(
            f'The type of the expr "{type(expr)}" does not have a supported derivative.'
        )
Beispiel #23
0
def teg_smoothstep(x):
    """Build the cubic smoothstep of ``x`` as a Teg expression.

    Evaluates to 0 for ``x <= 0``, ``3x^2 - 2x^3`` for ``0 < x < 1``, and 1 for ``x >= 1``.
    """
    # Sub-expressions are created in the same order as a single nested call would create them.
    above_zero = x > 0
    below_one = x < 1
    cubic = 3 * Sqr(x) - 2 * Sqr(x) * x
    ramp = IfElse(below_one, cubic, Const(1))
    return IfElse(above_zero, ramp, Const(0))
Beispiel #24
0
def reverse_deriv(expr: ITeg,
                  out_deriv_vals: Tup = None,
                  output_list: Optional[List[Var]] = None,
                  args: Dict[str, Any] = None) -> Any:
    """Computes the reverse-mode derivative of a given expression.

    Args:
        expr: The expression to compute the total derivative of. When ``out_deriv_vals``
            has more than one entry, ``expr`` is iterated alongside it (one sub-expression
            per output value).
        out_deriv_vals: The incoming derivative value for each output of ``expr``.
            Defaults to ``Tup(Const(1))``.
        output_list: Optional variables whose derivatives are returned, in this order.
            Variables that receive no contribution get ``Const(0)``. When omitted, every
            variable with a contribution is returned, sorted by uid.
        args: Additional mappings for specifying alternative behavior such as
            'ignore_deltas' and 'ignore_bounds'.

    Returns:
        With a single output value, a pair ``(d_vars, d_expr)``. ``d_vars`` holds fresh
        ``Var`` objects named ``d<name>``. ``d_expr`` is a ``Tup`` of the partial
        derivatives when there are several variables, otherwise the single derivative.
        With several output values, a ``Tup`` of such pairs.

    Raises:
        AssertionError: if there are no variables to differentiate with respect to, or if
            the number of output values does not match ``expr``.
    """

    if out_deriv_vals is None:
        out_deriv_vals = Tup(Const(1))

    if args is None:
        args = {}

    def derivs_for_single_outval(
            expr: ITeg,
            single_outval: Const,
            i: Optional[int] = None,
            output_list: Optional[List[Var]] = None,
            args: Dict[str, Any] = None) -> Tuple[List[Var], ITeg]:
        # `i` (the output index) is accepted for call-site symmetry but is not used.
        partial_deriv_map = defaultdict(lambda: Const(0))

        # After deriv_transform, expr will have unbound infinitesimals.
        # Sum every contribution that targets the same (d<name>, uid) key.
        for name_uid, e in reverse_deriv_transform(expr, single_outval, set(),
                                                   {}, args):
            partial_deriv_map[name_uid] += e

        # Introduce fresh variables for each partial derivative
        uids = [var_uid for var_name, var_uid in partial_deriv_map.keys()]
        new_vars = [
            Var(var_name) for var_name, var_uid in partial_deriv_map.keys()
        ]
        new_vals = [*partial_deriv_map.values()]

        if output_list is not None:
            # Return requested list of outputs; variables without contributions are zero.
            var_map = {
                uid: (var, val)
                for uid, var, val in zip(uids, new_vars, new_vals)
            }
            requested = [
                var_map.get(var.uid, (Var(f'd{var.name}'), Const(0)))
                for var in output_list
            ]
            # Check before unpacking: zip(*[]) cannot be unpacked into two names.
            assert len(
                requested) > 0, 'There must be variables to compute derivatives. '
            new_vars, new_vals = zip(*requested)
        else:
            # Return a list sorted in the order the variables were defined (by uid).
            sorted_list = sorted(zip(uids, new_vars, new_vals),
                                 key=lambda a: a[0])
            # Check before unpacking: zip(*[]) cannot be unpacked into three names.
            assert len(
                sorted_list) > 0, 'There must be variables to compute derivatives. '
            _, new_vars, new_vals = zip(*sorted_list)

        return new_vars, (Tup(*new_vals) if len(new_vars) > 1 else new_vals[0])

    if len(out_deriv_vals) == 1:
        single_outval = out_deriv_vals.children[0]
        derivs = derivs_for_single_outval(expr,
                                          single_outval,
                                          0,
                                          output_list=output_list,
                                          args=args)
    else:
        assert len(out_deriv_vals) == len(expr), \
            f'Expected out_deriv to have "{len(expr)}" values, but got "{len(out_deriv_vals)}" values.'

        derivs = (
            derivs_for_single_outval(e,
                                     single_outval,
                                     i,
                                     output_list=output_list,
                                     args=args)
            for i, (e, single_outval) in enumerate(zip(expr, out_deriv_vals)))
        derivs = Tup(*derivs)
    return derivs
Beispiel #25
0
from teg.lang.extended import (BiMap)

from teg.math import (Sin, Cos, Sqrt, ATan2, Sqr)

import numpy as np


def teg_max(a, b):
    """Return a Teg expression that selects ``a`` when ``a > b`` and ``b`` otherwise."""
    a_is_larger = a > b
    return IfElse(a_is_larger, a, b)


def teg_min(a, b):
    """Return a Teg expression that selects ``b`` when ``a > b`` and ``a`` otherwise."""
    a_is_larger = a > b
    return IfElse(a_is_larger, b, a)


# Pi-valued constants (-pi, pi and 2*pi), wrapped once as Teg ``Const`` nodes.
TEG_NEGATIVE_PI = Const(-np.pi)
TEG_PI = Const(np.pi)
TEG_2_PI = Const(2 * np.pi)


def polar_2d_map(expr, x, y, r):
    """
    Create a polar 2D map with x=0, y=0 as center and negative y axis as 0 & 2PI
    """
    theta = TegVar('theta')

    distance_to_origin = Sqrt(
        Sqr((y.lb() + y.ub()) / 2) + Sqr((x.lb() + x.ub()) / 2))
    box_radius = Sqrt(Sqr((y.ub() - y.lb()) / 2) + Sqr((x.ub() - x.lb()) / 2))

    # Manual interval arithmetic for conservative polar bounds.
Beispiel #26
0
def constant_coeff(affine: Dict[Tuple[str, int], ITeg]):
    """Return the constant term of an affine map, or ``Const(0)`` when it has none.

    The constant term is stored under the sentinel key ``('__const__', -1)``.
    """
    const_key = ('__const__', -1)
    if const_key not in affine:
        return Const(0)
    return affine[const_key]
Beispiel #27
0
 def coeff(u, v):
     """Entry (u, v) of the rotation built from the enclosing scope's ``exprs``.

     Column 0 is ``-exprs[u]``. Every other column is
     ``delta(u, v) - exprs[u] * exprs[v] / (1 + exprs[0])``.
     """
     if v != 0:
         kronecker = Const(1) if u == v else Const(0)
         return kronecker - (exprs[u] * exprs[v]) / (Const(1) + exprs[0])
     return -exprs[u]
Beispiel #28
0
def fwd_deriv_transform(
    expr: ITeg, ctx: Dict[Tuple[str, int], ITeg],
    not_ctx: Set[Tuple[str, int]], deps: Dict[TegVar, Set[Var]]
) -> Tuple[ITeg, Dict[Tuple[str, int], ITeg], Set[Tuple[str, int]],
           Dict[TegVar, Set[Var]]]:
    """Compute the source-to-source forward derivative of the given expression.

    Args:
        expr: The expression to differentiate.
        ctx: Maps a variable's ``(name, uid)`` to the expression for its infinitesimal.
            Variables absent from ``ctx`` have derivative zero.
        not_ctx: ``(name, uid)`` keys of variables that must not be differentiated, such
            as integration variables in scope.
        deps: Maps a ``TegVar`` to the variables it depends on (consulted through
            ``extend_dependencies``).

    Returns:
        A 4-tuple ``(deriv_expr, ctx, not_ctx, deps)``: the derivative expression and the
        contexts threaded through the traversal. Several branches update ``ctx`` and
        ``not_ctx`` in place.

    Raises:
        ValueError: if ``expr`` is of an unsupported type.
    """
    if isinstance(expr, TegVar):
        # Differentiate a TegVar when it is unmasked, or when it depends on an unmasked
        # variable, and it has an infinitesimal in ctx.
        if (((expr.name, expr.uid) not in not_ctx
             or {(v.name, v.uid)
                 for v in extend_dependencies({expr}, deps)} - not_ctx)
                and (expr.name, expr.uid) in ctx):
            expr = ctx[(expr.name, expr.uid)]
        else:
            expr = Const(0)

    elif isinstance(expr, (Const, Placeholder, Delta)):
        expr = Const(0)

    elif isinstance(expr, Var):
        if (expr.name, expr.uid) not in not_ctx and (expr.name,
                                                     expr.uid) in ctx:
            expr = ctx[(expr.name, expr.uid)]
        else:
            expr = Const(0)

    elif isinstance(expr, SmoothFunc):
        # Chain rule through the function's own forward derivative.
        in_deriv_expr, ctx, not_ctx, deps = fwd_deriv_transform(
            expr.expr, ctx, not_ctx, deps)
        deriv_expr = expr.fwd_deriv(in_deriv_expr=in_deriv_expr)
        expr = deriv_expr

    elif isinstance(expr, Add):
        sum_of_derivs = Const(0)
        for child in expr.children:
            deriv_child, ctx, not_ctx, deps = fwd_deriv_transform(
                child, ctx, not_ctx, deps)
            sum_of_derivs += deriv_child

        expr = sum_of_derivs

    elif isinstance(expr, Mul):
        # NOTE: Consider n-ary multiplication.
        assert len(
            expr.children
        ) == 2, 'fwd_deriv only supports binary multiplication not n-ary.'
        expr1, expr2 = expr.children

        (deriv_expr1, ctx1, not_ctx1,
         _) = fwd_deriv_transform(expr1, ctx, not_ctx, deps)
        (deriv_expr2, ctx2, not_ctx2,
         _) = fwd_deriv_transform(expr2, ctx, not_ctx, deps)

        # Product rule.
        expr = expr1 * deriv_expr2 + expr2 * deriv_expr1
        ctx = {**ctx1, **ctx2}
        not_ctx = not_ctx1 | not_ctx2

    elif isinstance(expr, Invert):
        # d(1/u) = -(1/u)^2 du
        deriv_expr, ctx, not_ctx, deps = fwd_deriv_transform(
            expr.child, ctx, not_ctx, deps)
        expr = -expr * expr * deriv_expr

    elif isinstance(expr, IfElse):
        if_body, ctx, not_ctx1, _ = fwd_deriv_transform(
            expr.if_body, ctx, not_ctx, deps)
        else_body, ctx, not_ctx2, _ = fwd_deriv_transform(
            expr.else_body, ctx, not_ctx, deps)
        not_ctx = not_ctx1 | not_ctx2

        # Discontinuity terms. For each primitive comparison in the condition, emit
        # d(right - left) * jump * Delta(right - left).
        deltas = Const(0)
        for boolean in primitive_booleans_in(expr.cond, not_ctx, deps):
            jump = substitute(expr, boolean, true) - substitute(
                expr, boolean, false)
            delta_expr = boolean.right_expr - boolean.left_expr

            delta_deriv, ctx, _ignore_not_ctx, _ = fwd_deriv_transform(
                delta_expr, ctx, not_ctx, deps)
            deltas = deltas + delta_deriv * jump * Delta(delta_expr)

        expr = IfElse(expr.cond, if_body, else_body) + deltas

    elif isinstance(expr, Teg):
        # NOTE(review): ctx is keyed by (name, uid) tuples, so testing the TegVar object
        # itself presumably never fails; confirm whether the tuple key was intended.
        assert expr.dvar not in ctx, f'Names of infinitesimal "{expr.dvar}" are distinct from context "{ctx}"'
        # In int_x f(x), the variable x is in scope for the integrand f(x).
        # not_ctx holds (name, uid) keys, so the dvar's full key must be discarded
        # (the same key reverse_deriv_transform uses).
        not_ctx.discard((expr.dvar.name, expr.dvar.uid))

        # Include derivative contribution from moving boundaries of integration
        boundary_val, new_ctx, new_not_ctx = boundary_contribution(
            expr, ctx, not_ctx, deps)
        # Mask the integration variable inside the body.
        not_ctx.add((expr.dvar.name, expr.dvar.uid))

        body, ctx, not_ctx, _ = fwd_deriv_transform(expr.body, ctx, not_ctx,
                                                    deps)

        ctx.update(new_ctx)
        not_ctx |= new_not_ctx
        expr = Teg(expr.lower, expr.upper, body, expr.dvar) + boundary_val

    elif isinstance(expr, Tup):
        new_expr_list, new_ctx, new_not_ctx = [], Ctx(), set()
        for child in expr:
            child, ctx, not_ctx, _ = fwd_deriv_transform(
                child, ctx, not_ctx, deps)
            new_expr_list.append(child)
            new_ctx.update(ctx)
            new_not_ctx |= not_ctx
        ctx, not_ctx = new_ctx, new_not_ctx
        expr = Tup(*new_expr_list)

    elif isinstance(expr, LetIn):

        # Compute derivatives of each expression and bind them to the corresponding dvar
        new_vars_with_derivs, new_exprs_with_derivs = list(
            expr.new_vars), list(expr.new_exprs)
        new_deps = {}
        for v, e in zip(expr.new_vars, expr.new_exprs):
            if v in expr.expr:
                # By not passing in the updated contexts,
                # we require that assignments in let expressions are independent
                de, ctx, not_ctx, _ = fwd_deriv_transform(
                    e, ctx, not_ctx, deps)
                ctx[(v.name, v.uid)] = Var(f'd{v.name}')
                new_vars_with_derivs.append(ctx[(v.name, v.uid)])
                new_exprs_with_derivs.append(de)
                new_deps[v] = extract_vars(e)

        deps = {**deps, **new_deps}
        # We want an expression in terms of f'd{var_in_let_body}'.
        # The let-bound variables were added to ctx only for this body, so they are
        # removed from ctx afterwards.
        dexpr, ctx, not_ctx, _ = fwd_deriv_transform(expr.expr, ctx, not_ctx,
                                                     deps)
        for c in expr.new_vars:
            ctx.pop((c.name, c.uid), None)

        expr = LetIn(Tup(*new_vars_with_derivs), Tup(*new_exprs_with_derivs),
                     dexpr)

    elif isinstance(expr, BiMap):
        # TODO: is it possible to not repeat this code and make another recursive call instead?

        # Compute derivatives of each expression and bind them to the corresponding dvar
        new_vars_with_derivs, new_exprs_with_derivs = [], []
        for v, e in zip(expr.targets, expr.target_exprs):
            if v in expr.expr:
                # By not passing in the updated contexts, require independence of exprs in the body of the let expression
                de, ctx, not_ctx, _ = fwd_deriv_transform(
                    e, ctx, not_ctx, deps)
                ctx[(v.name, v.uid)] = Var(f'd{v.name}')
                new_vars_with_derivs.append(ctx[(v.name, v.uid)])
                new_exprs_with_derivs.append(de)

                not_ctx = not_ctx | {(v.name, v.uid)}

        # We want an expression in terms of f'd{var_in_let_body}'.
        # The targets were added to ctx only for this body, so they are removed from ctx
        # afterwards.
        # NOTE(review): unlike the LetIn branch, deps is not extended with the targets'
        # dependencies here; confirm whether that is intended.
        dexpr, ctx, not_ctx, _ = fwd_deriv_transform(expr.expr, ctx, not_ctx,
                                                     deps)
        for c in expr.targets:
            ctx.pop((c.name, c.uid), None)

        expr = LetIn(
            Tup(*new_vars_with_derivs), Tup(*new_exprs_with_derivs),
            BiMap(dexpr,
                  targets=expr.targets,
                  target_exprs=expr.target_exprs,
                  sources=expr.sources,
                  source_exprs=expr.source_exprs,
                  inv_jacobian=expr.inv_jacobian,
                  target_lower_bounds=expr.target_lower_bounds,
                  target_upper_bounds=expr.target_upper_bounds))

    else:
        raise ValueError(
            f'The type of the expr "{type(expr)}" does not have a supported fwd_derivative.'
        )

    return expr, ctx, not_ctx, deps
Beispiel #29
0
    def outer_fn(e, ctx):
        """Rewrite hook that eliminates the BiMap ``bimap`` taken from the enclosing scope.

        Presumably invoked bottom-up by a tree pass (e.g. ``base_pass``); confirm at the
        call site. It handles three node kinds:

        * The target ``BiMap`` itself.
          - It is replaced by ``IfElse(all sources within ctx['source_bounds'],
            e.expr * e.inv_jacobian, 0)``.
          - Its sources, targets, source expressions and target bounds are recorded in
            the returned context.
          - If any source is not in ``ctx['upper_tegvars']``, a warning is printed and
            ``Const(0)`` is returned.
        * A ``Teg`` over a recorded source.
          - The integral is dropped and its body kept.
          - A new integral over a target variable is queued, with bounds resolved from
            the source bounds.
          - Once every source is consumed, the body is wrapped in the source
            let-mappings, then the queued integrals, then any lifted lets (outermost).
        * A ``LetIn`` whose variables appear in a pending mapping, while sources are still
          pending. It is lifted out into ``ctx['dependent_mappings']`` and its body is
          returned in its place.

        NOTE(review): the Teg and LetIn branches mutate ``ctx`` in place (``remove``,
        ``pop`` and key assignment) instead of copying it.
        """
        if isinstance(e, BiMap) and (bimap is e):
            if not all([k in ctx['upper_tegvars'] for k in e.sources]):
                # BiMap is invalid, null everything.
                print(
                    f'WARNING: Attempting to map non-Teg vars {e.sources}, {ctx["upper_tegvars"]}'
                )
                return Const(0), ctx

            # Keep the mapped body only where every source lies strictly inside its bounds.
            # NOTE(review): reduce() has no initializer, so an empty ctx['source_bounds']
            # raises TypeError.
            bounds_checks = reduce(
                operator.and_,
                [(lb < dvar) & (ub > dvar)
                 for (dvar, (lb, ub)) in ctx['source_bounds'].items()])
            reparamaterized_expr = IfElse(bounds_checks,
                                          e.expr * e.inv_jacobian, Const(0))
            return (reparamaterized_expr, {
                **ctx, 'teg_sources': list(e.sources),
                'teg_targets': list(e.targets),
                'let_mappings':
                {s: sexpr
                 for s, sexpr in zip(e.sources, e.source_exprs)},
                'target_lower_bounds':
                {t: tlb
                 for t, tlb in zip(e.targets, e.target_lower_bounds)},
                'target_upper_bounds':
                {t: tub
                 for t, tub in zip(e.targets, e.target_upper_bounds)}
            })
        elif isinstance(e, Teg):

            if e.dvar in ctx.get('teg_sources', {}):
                ctx['teg_sources'].remove(e.dvar)
                # NOTE(review): takes the *last* remaining target by position, not by
                # identity with e.dvar; confirm the pairing order is intended.
                target_dvar = ctx['teg_targets'].pop()
                # Target bounds may refer to source bounds through '<uid>_ub' / '<uid>_lb'
                # placeholders.
                placeholders = {
                    **{
                        f'{svar.uid}_ub': upper
                        for svar, (lower, upper) in ctx['source_bounds'].items(
                        )
                    },
                    **{
                        f'{svar.uid}_lb': lower
                        for svar, (lower, upper) in ctx['source_bounds'].items(
                        )
                    }
                }
                target_lower_bounds = resolve_placeholders(
                    ctx['target_lower_bounds'][target_dvar], placeholders)
                target_upper_bounds = resolve_placeholders(
                    ctx['target_upper_bounds'][target_dvar], placeholders)

                # Add new teg to list.
                ctx['new_tegs'] = [
                    *ctx.get('new_tegs', []),
                    (target_dvar, (target_lower_bounds, target_upper_bounds))
                ]

                # Remove old teg.
                e = e.body

                # All sources consumed: rebuild the integral over the targets.
                if len(ctx['teg_sources']) == 0:
                    # Add let mappings here.
                    source_vars, source_exprs = zip(
                        *list(ctx['let_mappings'].items()))
                    e = LetIn(source_vars, source_exprs, e)

                    # Add new tegs here.
                    for (new_dvar, (new_lb, new_ub)) in ctx['new_tegs']:
                        e = Teg(new_lb, new_ub, e, new_dvar)

                    # Add dependent mappings here.
                    for new_vars, new_exprs in ctx.get('dependent_mappings',
                                                       []):
                        e = LetIn(new_vars, new_exprs, e)
                return e, ctx

        elif isinstance(e, LetIn):

            if len(ctx.get('teg_sources', {})) > 0:
                # Lift this let when a pending dependent mapping or a source mapping uses
                # one of its variables.
                if (any([
                        new_var in map_expr for new_var in e.new_vars
                        for map_vars, map_exprs in ctx.get(
                            'dependent_mappings', []) for map_expr in map_exprs
                ]) or any([
                        new_var in map_expr for new_var in e.new_vars
                        for map_var, map_expr in ctx.get('let_mappings',
                                                         {}).items()
                ])):
                    # reparametrization is dependent on this let_map. lift this map.
                    ctx['dependent_mappings'] = [
                        *ctx.get('dependent_mappings', []),
                        (e.new_vars, e.new_exprs)
                    ]
                    return e.expr, ctx

        return e, ctx
Beispiel #30
0
    def rewrite(delta, not_ctx=set()):
        """Rewrite a delta of a two-variable bilinear polynomial into hyperbolic coordinates.

        ``delta.expr`` is read as ``c_xy*x*y + c_x*x + c_y*y + c_1``. Its coefficients are
        taken from ``extract_coefficients_from_polynomial``, and the ``(name, uid)`` keys of
        ``not_ctx`` are passed to it. The steps are:

        1. Bind the coefficients in a ``LetIn``, multiplying every coefficient by the sign
           of ``c_xy`` so that the ``x*y`` coefficient is non-negative.
        2. Move to a zero set ``x_st * y_st = k``.
           - If a linear term may be present, ``x`` and ``y`` are scaled by
             ``sqrt(c_xy)`` and translated, with ``k = c_x*c_y/c_xy - c_1``.
           - Otherwise ``x_st, y_st = x, y`` and ``k = -c_1/c_xy``.
        3. If ``k`` is negative, flip the sign of ``x_st``.
        4. Parameterize ``x = a*t, y = a/t`` as two ``BiMap`` branches with deltas at
           ``a = +sqrt(|k|)`` and ``a = -sqrt(|k|)``.

        Args:
            delta: A ``Delta`` whose expression is such a polynomial.
            not_ctx: Variables passed to the coefficient extractor as non-variables. It is
                only read, never mutated, so the shared default is safe.

        Returns:
            A ``LetIn`` expression containing the sum of the two reparameterized branches.

        Raises:
            ValueError: if the polynomial does not mention two distinct variables.
        """
        # Extract polynomial coefficients.
        poly_set = extract_coefficients_from_polynomial(
            delta.expr, {(var.name, var.uid)
                         for var in not_ctx})

        # Collect the distinct non-constant variables in order of first appearance.
        # A variable occurs in several terms (e.g. x in both x*y and x). Deduplicating by
        # (name, uid) keeps x and y from ending up as the same variable.
        unique_vars, seen_keys = [], set()
        for term in poly_set:
            for var, _ in term:
                if var is CONST_VAR:
                    continue
                var_key = (var.name, var.uid)
                if var_key not in seen_keys:
                    seen_keys.add(var_key)
                    unique_vars.append(var)

        if len(unique_vars) < 2:
            raise ValueError(
                f'Expected a polynomial in two variables, but found {len(unique_vars)}.'
            )

        x = unique_vars[0]
        y = unique_vars[1]

        c_xy = get_poly_term(poly_set, {x: 1, y: 1})
        c_x = get_poly_term(poly_set, {x: 1})
        c_y = get_poly_term(poly_set, {y: 1})
        c_1 = get_poly_term(poly_set, {})

        c_xy_var = Var(f'c_{x.name}_{y.name}')
        c_x_var = Var(f'c_{x.name}')
        c_y_var = Var(f'c_{y.name}')
        c_1_var = Var('c_1')

        # Multiply the whole polynomial by sign(c_xy) so the x*y coefficient is non-negative.
        coeff_vars = [c_xy_var, c_x_var, c_y_var, c_1_var]
        coeff_exprs = [
            teg_abs(c_xy),
            IfElse(c_xy > 0, c_x, -c_x),
            IfElse(c_xy > 0, c_y, -c_y),
            IfElse(c_xy > 0, c_1, -c_1)
        ]

        sqrt_c_xy = Sqrt(c_xy_var)
        sqrt_c_xy_var = Var(f'{x.name}_{y.name}_sqrt')

        needs_transforms = (c_x != Const(0) or c_y != Const(0))

        if needs_transforms:
            # Complete the product so that c_xy*x*y + c_x*x + c_y*y + c_1 equals
            # (s*x + c_y/s) * (s*y + c_x/s) - (c_x*c_y/c_xy - c_1), with s = sqrt(c_xy).
            scale_map = partial(scale, scale=[sqrt_c_xy_var, sqrt_c_xy_var])

            translate_map = partial(
                translate,
                translate=[c_y_var / sqrt_c_xy_var, c_x_var / sqrt_c_xy_var])

            scaler, (x_s, y_s) = scale_map([x, y])
            translater, (x_st, y_st) = translate_map([x_s, y_s])

            sqr_constant = (c_x_var * c_y_var) / (c_xy_var) - c_1_var
            scale_jacobian = Const(1)
        else:
            x_st, y_st = x, y
            sqr_constant = -c_1_var / c_xy_var
            scale_jacobian = c_xy_var

        # If threshold is negative, the hyperbola is in the second and fourth quadrants.
        # Inverting either one of x or y automatically handles this.
        conditional_inverter, (x_st, ) = scale(
            [x_st], scale=[IfElse(sqr_constant > 0, 1, -1)])
        adjusted_sqr_constant = teg_abs(sqr_constant)
        constant = Sqrt(adjusted_sqr_constant)

        # Hyperbolic transform
        hyp_a, hyp_t = TegVar('hyp_a'), TegVar('hyp_t')

        # Build bounds transfer expressions.
        pos_a_lb = teg_cases([Sqrt(x_st.lb() * y_st.lb()),
                              Const(0)], [(x_st.lb() > 0) & (y_st.lb() > 0)])

        pos_a_ub = teg_cases([Sqrt(x_st.ub() * y_st.ub()),
                              Const(0)], [(x_st.ub() > 0) & (y_st.ub() > 0)])

        neg_a_lb = teg_cases([-Sqrt(x_st.lb() * y_st.lb()),
                              Const(0)], [(x_st.lb() < 0) & (y_st.lb() < 0)])
        neg_a_ub = teg_cases([-Sqrt(x_st.ub() * y_st.ub()),
                              Const(0)], [(x_st.ub() < 0) & (y_st.ub() < 0)])

        pos_t_lb = teg_max(
            teg_cases([
                teg_max(x_st.lb() / hyp_a, hyp_a / y_st.ub()),
                hyp_a / y_st.ub(), MIN_T
            ], [(y_st.ub() > 0) & (x_st.lb() > 0),
                y_st.ub() > 0]), MIN_T)

        pos_t_ub = teg_min(
            teg_cases([
                teg_min(x_st.ub() / hyp_a, hyp_a / y_st.lb()),
                x_st.ub() / hyp_a, MAX_T
            ], [(x_st.ub() > 0) & (y_st.lb() > 0),
                x_st.ub() > 0]), MAX_T)

        neg_t_lb = teg_max(
            teg_cases([
                teg_max(x_st.ub() / hyp_a, hyp_a / y_st.lb()),
                hyp_a / y_st.lb(), MIN_T
            ], [(y_st.lb() < 0) & (x_st.ub() < 0),
                y_st.lb() < 0]), MIN_T)
        neg_t_ub = teg_min(
            teg_cases([
                teg_min(x_st.lb() / hyp_a, hyp_a / y_st.ub()),
                x_st.lb() / hyp_a, MAX_T
            ], [(x_st.lb() < 0) & (y_st.ub() < 0),
                x_st.lb() < 0]), MAX_T)

        # Branch with a > 0: x = a*t, y = a/t, delta at a = +sqrt(|k|).
        pos_curve = BiMap(Delta(hyp_a - constant),
                          sources=[x_st, y_st],
                          source_exprs=[hyp_a * hyp_t, hyp_a / hyp_t],
                          targets=[hyp_a, hyp_t],
                          target_exprs=[Sqrt(x_st * y_st),
                                        Sqrt(x_st / y_st)],
                          inv_jacobian=(hyp_a / hyp_t) *
                          (1 / (constant * scale_jacobian)),
                          target_lower_bounds=[pos_a_lb, pos_t_lb],
                          target_upper_bounds=[pos_a_ub, pos_t_ub])

        # Branch with a < 0: delta at a = -sqrt(|k|).
        neg_curve = BiMap(Delta(hyp_a + constant),
                          sources=[x_st, y_st],
                          source_exprs=[hyp_a * hyp_t, hyp_a / hyp_t],
                          targets=[hyp_a, hyp_t],
                          target_exprs=[-Sqrt(x_st * y_st),
                                        Sqrt(x_st / y_st)],
                          inv_jacobian=(-1 * hyp_a / hyp_t) *
                          (1 / (constant * scale_jacobian)),
                          target_lower_bounds=[neg_a_lb, neg_t_lb],
                          target_upper_bounds=[neg_a_ub, neg_t_ub])

        if needs_transforms:
            return LetIn(
                coeff_vars, coeff_exprs,
                LetIn([sqrt_c_xy_var], [sqrt_c_xy],
                      scaler(
                          translater(
                              conditional_inverter(pos_curve + neg_curve)))))
        else:
            return LetIn(coeff_vars, coeff_exprs,
                         conditional_inverter(pos_curve + neg_curve))