示例#1
0
文件: delta.py 项目: PeterZhouSZ/Teg
    def outer_fn(e, ctx):
        """Rebuild ``e`` keeping only the sub-terms that contain the tracked expression.

        Post-order callback. ``ctx['exprs']`` holds the already-rewritten
        children of ``e`` and ``ctx['has_exprs']`` holds one flag per child
        saying whether that child contains the tracked expression. These keys,
        and ``ctx['let_body']`` used for ``LetIn``, are presumably filled in by
        the enclosing pass's context combiner -- confirm against the caller.

        Sets ``ctx['has_expr']`` (whether any child contains the expression)
        and ``ctx['expr']`` (the rewritten node).

        Args:
            e: The node being visited.
            ctx: The per-node context dict described above.

        Returns:
            The pair ``(rewritten_expr, ctx)``.
        """
        ctx['has_expr'] = any(ctx['has_exprs'])
        # Check if we need to handle other such cases.
        assert not (ctx['has_expr'] and isinstance(e, SmoothFunc)),\
               f'expr is contained in a non-linear function {type(e)}'
        if isinstance(e, Add):
            if ctx['has_expr']:
                # Keep only the summands that contain the expression.
                # NOTE: sum() starts from the int 0, so this relies on the
                # expression type supporting `0 + child`.
                ctx['expr'] = sum([
                    child
                    for child, has_expr in zip(ctx['exprs'], ctx['has_exprs'])
                    if has_expr
                ])
                return ctx['expr'], ctx
            else:
                ctx['expr'] = e
                return e, ctx
        elif isinstance(e, Tup):
            if ctx['has_expr']:
                # Zero out elements without the expression so the arity of the
                # tuple is preserved.
                ctx['expr'] = Tup(*[
                    ctx['exprs'][idx] if has_expr else Const(0)
                    for idx, has_expr in enumerate(ctx['has_exprs'])
                ])
                return ctx['expr'], ctx
        elif isinstance(e, IfElse):
            if ctx['has_expr']:
                # Keep the original condition; indices 1 and 2 are the two
                # branch bodies, zeroed when they lack the expression.
                ctx['expr'] = IfElse(
                    e.cond,
                    ctx['exprs'][1] if ctx['has_exprs'][1] else Const(0),
                    ctx['exprs'][2] if ctx['has_exprs'][2] else Const(0))
                # Same (expr, ctx) order as every other branch.
                return ctx['expr'], ctx
        elif isinstance(e, LetIn):
            if any(ctx['has_exprs'][1:]):
                # Let expressions contain exprs.
                new_exprs = [
                    let_var for let_var, has_expr in zip(
                        e.new_vars, ctx['has_exprs'][1:]) if has_expr
                ]
                # Recursively split the body with the new expressions.
                s_expr = split_exprs(new_exprs, ctx['let_body'])
                let_body = (s_expr if s_expr else Const(0)) +\
                           (e.expr if ctx['has_exprs'][0] else Const(0))
                # Keep only the bindings still referenced by the new body and
                # drop the LetIn wrapper entirely when none are.
                kept_bindings = [(let_var, let_val)
                                 for let_var, let_val in zip(
                                     e.new_vars, e.new_exprs)
                                 if let_var in let_body]
                if kept_bindings:
                    vs, es = zip(*kept_bindings)
                    ctx['expr'] = LetIn(vs, es, let_body)
                else:
                    # No need for a let expr.
                    ctx['expr'] = let_body

                return ctx['expr'], ctx

        ctx['expr'] = e
        return ctx['expr'], ctx
示例#2
0
    def derivs_for_single_outval(
            expr: ITeg,
            single_outval: Const,
            i: Optional[int] = None,
            output_list: Optional[List[Var]] = None,
            args: Dict[str, Any] = None) -> Tuple[List[Var], ITeg]:
        """Collect the reverse-mode partial derivatives of ``expr`` for one seed.

        Args:
            expr: Expression to differentiate.
            single_outval: Seed for the output derivative, forwarded to
                ``reverse_deriv_transform``.
            i: Index of this output; not used by the body.
            output_list: If given, the variables whose derivatives are
                returned, in this order. A variable with no contribution gets
                a fresh ``Var(f'd{name}')`` paired with ``Const(0)``.
            args: Extra options forwarded unchanged to
                ``reverse_deriv_transform``.

        Returns:
            ``(new_vars, value)``. ``new_vars`` holds one fresh ``Var`` per
            partial derivative. ``value`` is a ``Tup`` of the derivative
            expressions, or the single expression itself when there is only
            one.

        Raises:
            ValueError: If there are no derivatives to return.
        """
        partial_deriv_map = defaultdict(lambda: Const(0))

        # After deriv_transform, expr will have unbound infinitesimals.
        # Contributions to the same (name, uid) key are summed.
        for name_uid, e in reverse_deriv_transform(expr, single_outval, set(),
                                                   {}, args):
            partial_deriv_map[name_uid] += e

        # Introduce fresh variables for each partial derivative, in discovery
        # order: (uid of original variable, fresh var, derivative expr).
        entries = [(var_uid, Var(var_name), val)
                   for (var_name, var_uid), val in partial_deriv_map.items()]

        if output_list is not None:
            # Return requested list of outputs.
            var_map = {uid: (var, val) for uid, var, val in entries}
            selected = [
                var_map.get(var.uid, (Var(f'd{var.name}'), Const(0)))
                for var in output_list
            ]
        else:
            # Return a list sorted by uid (stable sort).
            entries.sort(key=lambda entry: entry[0])
            selected = [(var, val) for _, var, val in entries]

        # Check before unpacking: zip(*[]) would otherwise fail with an opaque
        # "not enough values to unpack" error.
        if not selected:
            raise ValueError('There must be variables to compute derivatives.')
        new_vars, new_vals = zip(*selected)
        return new_vars, (Tup(*new_vals) if len(new_vars) > 1 else new_vals[0])
示例#3
0
def simplify(expr: ITeg) -> ITeg:
    """Return an algebraically simplified, equivalent version of ``expr``.

    Children are simplified first (bottom-up), then local rewrite rules are
    applied:
      * additive/multiplicative identity and zero elimination,
      * folding of constant sub-expressions via ``evaluate``,
      * regrouping of constants in binary ``Add``/``Mul`` trees,
      * merging sums of ``LetIn``/``Teg``/``IfElse`` nodes that share
        bindings, integration domain, or condition,
      * factoring a common factor out of a sum of two products,
      * propagating constant ``LetIn`` bindings into the body,
      * short-circuiting ``IfElse``/``And``/``Or`` on literal booleans.
    ``FwdDeriv``/``RevDeriv`` wrappers are replaced by their simplified
    ``deriv_expr``.

    Args:
        expr: The expression to simplify.

    Returns:
        The simplified expression.

    Raises:
        ValueError: If ``expr`` has a node type with no simplify rule.
    """

    if isinstance(expr, Var):
        return expr

    elif isinstance(expr, Add):
        expr1, expr2 = expr.children
        simple1, simple2 = simplify(expr1), simplify(expr2)
        if isinstance(simple1, Const) and simple1.value == 0:
            return simple2
        if isinstance(simple2, Const) and simple2.value == 0:
            return simple1
        if isinstance(simple1, Const) and isinstance(simple2, Const):
            return Const(evaluate(simple1 + simple2))

        # Associative reordering.
        if isinstance(simple1,
                      (Add, Const)) and isinstance(simple2, (Add, Const)):
            # list(...) so the concatenation below works whether `children`
            # is a list or a tuple.
            nodes1 = [simple1] if isinstance(simple1, Const) else list(
                simple1.children)
            nodes2 = [simple2] if isinstance(simple2, Const) else list(
                simple2.children)
            all_nodes = nodes1 + nodes2
            assert 2 <= len(
                all_nodes
            ) <= 4, 'Unexpected number of nodes in Add-associative tree'

            const_nodes = [
                node for node in all_nodes if isinstance(node, Const)
            ]
            other_nodes = [
                node for node in all_nodes if not isinstance(node, Const)
            ]

            # No const nodes -> Reordering is pointless.
            if len(other_nodes) == len(all_nodes):
                return simple1 + simple2

            # Compress const nodes.
            const_node = Const(evaluate(reduce(operator.add, const_nodes)))

            # Drop a zero constant. Otherwise append it last so the left fold
            # below places it at the root of the rebuilt tree.
            if const_node == Const(0):
                simplified_nodes = other_nodes
            else:
                simplified_nodes = other_nodes + [const_node]

            return reduce(operator.add, simplified_nodes)

        if isinstance(simple1, LetIn) and isinstance(simple2, LetIn):
            # Identical bindings: sum the bodies under one LetIn.
            if simple1.new_vars == simple2.new_vars and simple1.new_exprs == simple2.new_exprs:
                return LetIn(new_vars=simple1.new_vars,
                             new_exprs=simple1.new_exprs,
                             expr=simplify(simple1.expr + simple2.expr))
            else:
                return simple1 + simple2

        if isinstance(simple1, Teg) and isinstance(simple2, Teg):
            # Same variable and bounds: integrate the summed integrand once.
            if (simple1.dvar == simple2.dvar and simple1.lower == simple2.lower
                    and simple1.upper == simple2.upper):
                return simplify(
                    Teg(simple1.lower, simple1.upper,
                        simplify(simple1.body + simple2.body), simple1.dvar))

            else:
                return simple1 + simple2

        if isinstance(simple1, IfElse) and isinstance(simple2, IfElse):
            # Same condition: add branch-wise.
            if simple1.cond == simple2.cond:
                return IfElse(simple1.cond,
                              simplify(simple1.if_body + simple2.if_body),
                              simplify(simple1.else_body + simple2.else_body))
            else:
                return simple1 + simple2

        if isinstance(simple1, Mul) and isinstance(simple2, Mul):
            # Distribution: factor a shared operand out of a*b + c*d.
            exprLL, exprLR = simple1.children
            exprRL, exprRR = simple2.children

            if exprLL == exprRR:
                return simplify(exprLL * (simplify(exprLR + exprRL)))
            if exprLL == exprRL:
                return simplify(exprLL * (simplify(exprLR + exprRR)))
            if exprLR == exprRL:
                return simplify(exprLR * (simplify(exprLL + exprRR)))
            if exprLR == exprRR:
                return simplify(exprLR * (simplify(exprLL + exprRL)))

        return simple1 + simple2

    elif isinstance(expr, Mul):
        expr1, expr2 = expr.children
        simple1, simple2 = simplify(expr1), simplify(expr2)

        # 0-elimination
        if ((isinstance(simple1, Const) and simple1.value == 0)
                or (isinstance(simple2, Const) and hasattr(simple2, 'value')
                    and simple2.value == 0)):
            return Const(0)

        # Multiplicative inverse.
        if isinstance(simple1, Const) and simple1.value == 1.0:
            return simple2
        if isinstance(simple2, Const) and simple2.value == 1.0:
            return simple1

        # Local constant compression.
        if isinstance(simple1, Const) and isinstance(simple2, Const):
            return Const(evaluate(simple1 * simple2))

        # Associative reordering.
        if isinstance(simple1,
                      (Mul, Const)) and isinstance(simple2, (Mul, Const)):
            # list(...) so the concatenation below works whether `children`
            # is a list or a tuple.
            nodes1 = [simple1] if isinstance(simple1, Const) else list(
                simple1.children)
            nodes2 = [simple2] if isinstance(simple2, Const) else list(
                simple2.children)
            all_nodes = nodes1 + nodes2
            assert 2 <= len(
                all_nodes
            ) <= 4, 'Unexpected number of nodes in Mul-associative tree'

            const_nodes = [
                node for node in all_nodes if isinstance(node, Const)
            ]
            other_nodes = [
                node for node in all_nodes if not isinstance(node, Const)
            ]

            # No const nodes -> Reordering is pointless.
            if len(other_nodes) == len(all_nodes):
                return simple1 * simple2

            # Compress const nodes.
            const_node = Const(evaluate(reduce(operator.mul, const_nodes)))

            # Drop a unit constant. Otherwise append it last so the left fold
            # below places it at the root of the rebuilt tree.
            if not (const_node == Const(1)):
                simplified_nodes = other_nodes + [const_node]
            else:
                simplified_nodes = other_nodes

            return reduce(operator.mul, simplified_nodes)

        return simple1 * simple2

    elif isinstance(expr, Invert):
        simple = simplify(expr.child)
        if isinstance(simple, Const):
            return Const(evaluate(Invert(simple)))
        return Invert(simple)

    elif isinstance(expr, SmoothFunc):
        simple = simplify(expr.expr)
        if isinstance(simple, Const):
            return Const(evaluate(type(expr)(simple)))
        # Reuse the already-simplified child. Simplifying it a second time
        # made nested smooth functions cost time exponential in their depth.
        return type(expr)(simple)

    elif isinstance(expr, IfElse):
        cond, if_body, else_body = simplify(expr.cond), simplify(
            expr.if_body), simplify(expr.else_body)
        if (isinstance(if_body, Const) and isinstance(else_body, Const)
                and if_body.value == 0 and else_body.value == 0):
            return if_body

        if cond == true:
            return if_body

        if cond == false:
            return else_body

        return IfElse(cond, if_body, else_body)

    elif isinstance(expr, Teg):
        body = simplify(expr.body)
        if isinstance(body, Const) and hasattr(body,
                                               'value') and body.value == 0:
            return Const(0)
        return Teg(simplify(expr.lower), simplify(expr.upper), body, expr.dvar)

    elif isinstance(expr, Tup):
        return Tup(*(simplify(child) for child in expr))

    elif isinstance(expr, LetIn):
        simplified_exprs = Tup(*(simplify(e) for e in expr.new_exprs))
        child_expr = simplify(expr.expr)
        vars_list = expr.new_vars

        # Inline bindings whose value simplified to a constant.
        for s_var, s_expr in zip(vars_list, simplified_exprs):
            if isinstance(s_expr, Const):
                child_expr = substitute(child_expr, s_var, s_expr)

        non_const_bindings = [
            (s_var, s_expr)
            for s_var, s_expr in zip(vars_list, simplified_exprs)
            if not isinstance(s_expr, Const)
        ]

        child_expr = simplify(child_expr)
        if non_const_bindings:
            non_const_vars, non_const_exprs = zip(*non_const_bindings)
            return (LetIn(non_const_vars, non_const_exprs, child_expr)
                    if not isinstance(child_expr, Const) else child_expr)
        else:
            return child_expr

    elif isinstance(expr, BiMap):
        simplified_target_exprs = list(simplify(e) for e in expr.target_exprs)
        simplified_source_exprs = list(simplify(e) for e in expr.source_exprs)

        simplified_ubs = list(simplify(e) for e in expr.target_upper_bounds)
        simplified_lbs = list(simplify(e) for e in expr.target_lower_bounds)

        child_expr = simplify(expr.expr)

        return BiMap(expr=child_expr,
                     targets=expr.targets,
                     target_exprs=simplified_target_exprs,
                     sources=expr.sources,
                     source_exprs=simplified_source_exprs,
                     inv_jacobian=simplify(expr.inv_jacobian),
                     target_lower_bounds=simplified_lbs,
                     target_upper_bounds=simplified_ubs)

    elif isinstance(expr, Delta):
        return Delta(simplify(expr.expr))

    elif {'FwdDeriv', 'RevDeriv'} & {t.__name__ for t in type(expr).__mro__}:
        # Matched by class name in the MRO rather than by isinstance.
        return simplify(expr.__getattribute__('deriv_expr'))

    elif isinstance(expr, Bool):
        left_expr, right_expr = simplify(expr.left_expr), simplify(
            expr.right_expr)
        if isinstance(left_expr, Const) and isinstance(right_expr, Const):
            return false if evaluate(Bool(left_expr,
                                          right_expr)) == 0.0 else true
        return Bool(left_expr, right_expr)

    elif isinstance(expr, And):
        left_expr, right_expr = simplify(expr.left_expr), simplify(
            expr.right_expr)
        if left_expr == true:
            return right_expr
        if right_expr == true:
            return left_expr
        if left_expr == false or right_expr == false:
            return false
        return And(left_expr, right_expr)

    elif isinstance(expr, Or):
        left_expr, right_expr = simplify(expr.left_expr), simplify(
            expr.right_expr)
        if left_expr == false:
            return right_expr
        if right_expr == false:
            return left_expr
        if left_expr == true or right_expr == true:
            return true
        return Or(left_expr, right_expr)

    else:
        raise ValueError(
            f'The type of the expr "{type(expr)}" does not have a supported simplify rule'
        )
示例#4
0
def do_pass(expr: ITeg, context, inner_fn, outer_fn,
            context_combine) -> Tuple[ITeg, Dict]:
    """Apply a generic bottom-up rewrite pass to ``expr``.

    Each child ``c`` of a node is first passed through
    ``inner_fn(c, context)``. The resulting ``(expr, context)`` pair is then
    processed recursively. After all children are processed, the node is
    rebuilt only if some child changed identity; otherwise the original node
    object is reused. The result is
    ``outer_fn(node, context_combine(child_contexts, context))``.
    Leaf nodes (``Const``, ``Var``, ``TegVar``, ``Placeholder``) go straight to
    ``outer_fn`` with ``context_combine([], context)``.

    Args:
        expr: Root of the expression tree to traverse.
        context: Context given to ``inner_fn`` for every child of this node
            and to ``context_combine``.
        inner_fn: ``(child, context) -> (child, context)`` pre-visit hook.
        outer_fn: ``(node, context) -> (node, context)`` post-visit hook.
        context_combine: ``(child_contexts, context) -> context``. Merges the
            contexts returned for a node's children, in the order each
            branch below lists them.

    Returns:
        The ``(expr, context)`` pair produced by ``outer_fn`` for the root.

    Raises:
        ValueError: If ``expr`` has an unsupported node type.
    """

    def recurse(child):
        # Pre-visit `child` with inner_fn, then run the pass on the result.
        return do_pass(*inner_fn(child, context), inner_fn, outer_fn,
                       context_combine)

    def recurse_all(children):
        # Run the pass on each child in order. Returns (new_children,
        # child_contexts) as tuples, which are both empty when there are no
        # children (zip(*[]) would fail to unpack in that case).
        results = [recurse(child) for child in children]
        return (tuple(new_child for new_child, _ in results),
                tuple(child_ctx for _, child_ctx in results))

    def all_same(new_children, old_children):
        # True when every rewritten child is the very object it replaced.
        return all(new_child is old_child
                   for new_child, old_child in zip(new_children, old_children))

    if isinstance(expr, (Const, Var, TegVar, Placeholder)):
        new_expr, out_context = outer_fn(expr, context_combine([], context))
        return new_expr, out_context

    elif isinstance(expr, Add):
        left_expr, left_ctx = recurse(expr.children[0])
        right_expr, right_ctx = recurse(expr.children[1])
        if not (left_expr is expr.children[0]
                and right_expr is expr.children[1]):
            expr = left_expr + right_expr
        return outer_fn(expr, context_combine([left_ctx, right_ctx], context))

    elif isinstance(expr, Mul):
        left_expr, left_ctx = recurse(expr.children[0])
        right_expr, right_ctx = recurse(expr.children[1])
        if not (left_expr is expr.children[0]
                and right_expr is expr.children[1]):
            expr = left_expr * right_expr
        return outer_fn(expr, context_combine([left_ctx, right_ctx], context))

    elif isinstance(expr, Invert):
        child, child_ctx = recurse(expr.child)
        if child is not expr.child:
            expr = Invert(child)
        return outer_fn(expr, context_combine([child_ctx], context))

    elif isinstance(expr, SmoothFunc):
        child, child_ctx = recurse(expr.expr)
        if child is not expr.expr:
            expr = type(expr)(child)
        return outer_fn(expr, context_combine([child_ctx], context))

    elif isinstance(expr, IfElse):
        cond, cond_ctx = recurse(expr.cond)
        if_body, if_ctx = recurse(expr.if_body)
        else_body, else_ctx = recurse(expr.else_body)
        if (cond is not expr.cond or if_body is not expr.if_body
                or else_body is not expr.else_body):
            expr = IfElse(cond, if_body, else_body)
        return outer_fn(expr,
                        context_combine([cond_ctx, if_ctx, else_ctx], context))

    elif isinstance(expr, Teg):
        # The integration variable (expr.dvar) itself is not traversed.
        body, body_ctx = recurse(expr.body)
        lower, lower_ctx = recurse(expr.lower)
        upper, upper_ctx = recurse(expr.upper)
        if not (body is expr.body and lower is expr.lower
                and upper is expr.upper):
            expr = Teg(lower, upper, body, expr.dvar)
        return outer_fn(
            expr, context_combine([lower_ctx, upper_ctx, body_ctx], context))

    elif isinstance(expr, Tup):
        exprs, expr_contexts = recurse_all(expr)
        if not all_same(exprs, expr):
            expr = Tup(*exprs)
        return outer_fn(expr, context_combine(expr_contexts, context))

    elif isinstance(expr, LetIn):
        body_expr, body_context = recurse(expr.expr)
        let_exprs, let_contexts = recurse_all(expr.new_exprs)
        if not (all_same(let_exprs, expr.new_exprs)
                and body_expr is expr.expr):
            expr = LetIn(expr.new_vars, let_exprs, body_expr)
        return outer_fn(
            expr, context_combine([body_context, *let_contexts], context))

    elif isinstance(expr, Bool):
        left_expr, left_ctx = recurse(expr.left_expr)
        right_expr, right_ctx = recurse(expr.right_expr)
        if not (left_expr is expr.left_expr
                and right_expr is expr.right_expr):
            expr = Bool(left_expr, right_expr, allow_eq=expr.allow_eq)
        return outer_fn(expr, context_combine([left_ctx, right_ctx], context))

    elif isinstance(expr, (And, Or)):
        left_expr, left_ctx = recurse(expr.left_expr)
        right_expr, right_ctx = recurse(expr.right_expr)
        if not (left_expr is expr.left_expr
                and right_expr is expr.right_expr):
            expr = type(expr)(left_expr, right_expr)
        return outer_fn(expr, context_combine([left_ctx, right_ctx], context))

    elif isinstance(expr, BiMap):
        body_expr, body_context = recurse(expr.expr)
        source_exprs, source_contexts = recurse_all(expr.source_exprs)
        target_exprs, target_contexts = recurse_all(expr.target_exprs)
        jacobian_expr, jacobian_context = recurse(expr.inv_jacobian)
        target_upper_bounds, ub_contexts = recurse_all(
            expr.target_upper_bounds)
        target_lower_bounds, lb_contexts = recurse_all(
            expr.target_lower_bounds)

        unchanged = (
            all_same(source_exprs, expr.source_exprs)
            and all_same(target_exprs, expr.target_exprs)
            and all_same(target_upper_bounds, expr.target_upper_bounds)
            and all_same(target_lower_bounds, expr.target_lower_bounds)
            and body_expr is expr.expr
            and jacobian_expr is expr.inv_jacobian)
        if not unchanged:
            # Keyword arguments so each bound list lands in its own slot
            # regardless of the constructor's positional parameter order.
            expr = BiMap(expr=body_expr,
                         targets=expr.targets,
                         target_exprs=target_exprs,
                         sources=expr.sources,
                         source_exprs=source_exprs,
                         inv_jacobian=jacobian_expr,
                         target_lower_bounds=target_lower_bounds,
                         target_upper_bounds=target_upper_bounds)

        return outer_fn(
            expr,
            context_combine([
                *source_contexts, *target_contexts, jacobian_context,
                *ub_contexts, *lb_contexts, body_context
            ], context))

    elif isinstance(expr, Delta):
        body_expr, body_context = recurse(expr.expr)
        if body_expr is not expr.expr:
            expr = Delta(body_expr)
        return outer_fn(expr, context_combine([body_context], context))

    else:
        raise ValueError(
            f'The type of the expr "{type(expr)}" is not supported by substitute.'
        )
示例#5
0
    t3, t4 = Var('t3'), Var('t4')

    scale_map, (x_s, y_s) = scale([x, y], [t1, t2])
    translate_map, (x_st, y_st) = translate([x_s, y_s], [t3, t4])

    # Area of a unit circle.
    bindings = {t1: 1, t2: 1, t3: 0, t4: 0, t: 0.25}
    # Derivative of threshold only.
    integral = Teg(
                x_lb, x_ub,
                Teg(y_lb, y_ub,
                    scale_map(translate_map(IfElse(x_st * y_st > t, 1, 0))), y
                    ), x
                )

    d_vars, dt_exprs = reverse_deriv(integral, Tup(Const(1)), output_list=[t, t1, t2, t3, t4])

    integral = reduce_to_base(integral)
    image = render_image(integral,
                         variables=((x_lb, x_ub), (y_lb, y_ub)),
                         bindings=bindings,
                         bounds=((-1, 1), (-1, 1)),
                         res=(args.res_x, args.res_y),
                         )
    save_image(np.abs(image), filename=f'{args.testname}.png')

    for d_var, dt_expr in zip(d_vars, dt_exprs):
        image = render_image(reduce_to_base(dt_expr),
                             variables=((x_lb, x_ub), (y_lb, y_ub)),
                             bindings=bindings,
                             bounds=((-1, 1), (-1, 1)),
示例#6
0
def reverse_deriv(expr: ITeg,
                  out_deriv_vals: Tup = None,
                  output_list: Optional[List[Var]] = None,
                  args: Dict[str, Any] = None) -> ITeg:
    """Compute the reverse-mode derivative of a given expression.

    Args:
        expr: The expression to differentiate.
        out_deriv_vals: A ``Tup`` of seed values for the output derivative(s).
            Defaults to ``Tup(Const(1))``. With more than one element, it must
            have one element per element of ``expr``.
        output_list: If given, the variables whose derivatives are returned,
            in this order. Variables with no contribution get ``Const(0)``.
            If omitted, every variable that receives a contribution is
            returned, sorted by uid.
        args: Additional mappings for specifying alternative behavior such as
            'ignore_deltas' and 'ignore_bounds'.

    Returns:
        For a single seed value: ``(d_vars, deriv)``, where ``d_vars`` are
        fresh variables naming each partial derivative and ``deriv`` is the
        derivative expression (a ``Tup`` when there are several).
        For several seed values: a ``Tup`` of such pairs, one per output.
        The ``-> ITeg`` annotation does not describe the single-seed case.

    Raises:
        AssertionError: If the number of seed values does not match ``expr``.
        ValueError: If no partial derivatives are produced.
    """

    if out_deriv_vals is None:
        out_deriv_vals = Tup(Const(1))

    if args is None:
        args = {}

    def derivs_for_single_outval(
            expr: ITeg,
            single_outval: Const,
            i: Optional[int] = None,
            output_list: Optional[List[Var]] = None,
            args: Dict[str, Any] = None) -> Tuple[List[Var], ITeg]:
        # Collect the partial derivatives of `expr` for one seed value and
        # return (fresh derivative vars, Tup of derivatives or the single one).
        # `i` is not used by the body.
        partial_deriv_map = defaultdict(lambda: Const(0))

        # After deriv_transform, expr will have unbound infinitesimals.
        # Contributions to the same (name, uid) key are summed.
        for name_uid, e in reverse_deriv_transform(expr, single_outval, set(),
                                                   {}, args):
            partial_deriv_map[name_uid] += e

        # Introduce fresh variables for each partial derivative, in discovery
        # order: (uid of original variable, fresh var, derivative expr).
        entries = [(var_uid, Var(var_name), val)
                   for (var_name, var_uid), val in partial_deriv_map.items()]

        if output_list is not None:
            # Return requested list of outputs.
            var_map = {uid: (var, val) for uid, var, val in entries}
            selected = [
                var_map.get(var.uid, (Var(f'd{var.name}'), Const(0)))
                for var in output_list
            ]
        else:
            # Return a list sorted by uid (stable sort).
            entries.sort(key=lambda entry: entry[0])
            selected = [(var, val) for _, var, val in entries]

        # Check before unpacking: zip(*[]) would otherwise fail with an opaque
        # "not enough values to unpack" error.
        if not selected:
            raise ValueError('There must be variables to compute derivatives.')
        new_vars, new_vals = zip(*selected)
        return new_vars, (Tup(*new_vals) if len(new_vars) > 1 else new_vals[0])

    if len(out_deriv_vals) == 1:
        single_outval = out_deriv_vals.children[0]
        derivs = derivs_for_single_outval(expr,
                                          single_outval,
                                          0,
                                          output_list=output_list,
                                          args=args)
    else:
        assert len(out_deriv_vals) == len(expr), \
            f'Expected out_deriv to have "{len(expr)}" values, but got "{len(out_deriv_vals)}" values.'

        # One (vars, derivative) pair per output component of `expr`.
        derivs = (
            derivs_for_single_outval(e,
                                     single_outval,
                                     i,
                                     output_list=output_list,
                                     args=args)
            for i, (e, single_outval) in enumerate(zip(expr, out_deriv_vals)))
        derivs = Tup(*derivs)
    return derivs
示例#7
0
def fwd_deriv_transform(
    expr: ITeg, ctx: Dict[Tuple[str, int], ITeg],
    not_ctx: Set[Tuple[str, int]], deps: Dict[TegVar, Set[Var]]
) -> Tuple[ITeg, Dict[Tuple[str, int], ITeg], Set[Tuple[str, int]],
           Dict[TegVar, Set[Var]]]:
    """Compute the source-to-source forward derivative of the given expression.

    Args:
        expr: Expression to differentiate.
        ctx: Maps a variable's ``(name, uid)`` to the expression for its
            infinitesimal (e.g. ``Var(f'd{name}')``). A variable found here
            (and not excluded by ``not_ctx``) differentiates to that
            expression. The Teg, LetIn and BiMap branches may mutate this
            dict in place.
        not_ctx: ``(name, uid)`` keys of variables treated as constants
            (derivative ``Const(0)``). The Teg branch mutates this set in
            place.
        deps: Maps a variable to the variables it depends on. Consulted via
            ``extend_dependencies`` for TegVars; the LetIn branch adds entries
            for let-bound variables.

    Returns:
        A 4-tuple ``(deriv_expr, ctx, not_ctx, deps)``.

    Raises:
        ValueError: If ``expr`` has a node type with no forward-derivative rule.
    """
    if isinstance(expr, TegVar):
        # Use the infinitesimal from ctx when the TegVar is in ctx and is
        # either not held constant, or some variable returned by
        # extend_dependencies({expr}, deps) is not held constant.
        if (((expr.name, expr.uid) not in not_ctx
             or {(v.name, v.uid)
                 for v in extend_dependencies({expr}, deps)} - not_ctx)
                and (expr.name, expr.uid) in ctx):
            expr = ctx[(expr.name, expr.uid)]
        else:
            expr = Const(0)

    elif isinstance(expr, (Const, Placeholder, Delta)):
        # Constants, placeholders and Delta terms get a zero derivative here.
        expr = Const(0)

    elif isinstance(expr, Var):
        # A plain variable differentiates to its infinitesimal when it is in
        # ctx and not held constant.
        if (expr.name, expr.uid) not in not_ctx and (expr.name,
                                                     expr.uid) in ctx:
            expr = ctx[(expr.name, expr.uid)]
        else:
            expr = Const(0)

    elif isinstance(expr, SmoothFunc):
        # Chain rule: the node builds its own derivative from the child's.
        in_deriv_expr, ctx, not_ctx, deps = fwd_deriv_transform(
            expr.expr, ctx, not_ctx, deps)
        deriv_expr = expr.fwd_deriv(in_deriv_expr=in_deriv_expr)
        expr = deriv_expr

    elif isinstance(expr, Add):
        # Sum of the children's derivatives; ctx/not_ctx/deps are threaded
        # through the children in order.
        sum_of_derivs = Const(0)
        for child in expr.children:
            deriv_child, ctx, not_ctx, deps = fwd_deriv_transform(
                child, ctx, not_ctx, deps)
            sum_of_derivs += deriv_child

        expr = sum_of_derivs

    elif isinstance(expr, Mul):
        # NOTE: Consider n-ary multiplication.
        assert len(
            expr.children
        ) == 2, 'fwd_deriv only supports binary multiplication not n-ary.'
        expr1, expr2 = [child for child in expr.children]

        # Both factors are differentiated against the same incoming ctx and
        # not_ctx; the returned deps are discarded.
        (deriv_expr1, ctx1, not_ctx1,
         _) = fwd_deriv_transform(expr1, ctx, not_ctx, deps)
        (deriv_expr2, ctx2, not_ctx2,
         _) = fwd_deriv_transform(expr2, ctx, not_ctx, deps)

        # Product rule: d(a*b) = a*db + b*da.
        expr = expr1 * deriv_expr2 + expr2 * deriv_expr1
        ctx = {**ctx1, **ctx2}
        not_ctx = not_ctx1 | not_ctx2

    elif isinstance(expr, Invert):
        deriv_expr, ctx, not_ctx, deps = fwd_deriv_transform(
            expr.child, ctx, not_ctx, deps)
        # `expr` is still the Invert node (1/u) here, so this is
        # d(1/u) = -(1/u) * (1/u) * du.
        expr = -expr * expr * deriv_expr

    elif isinstance(expr, IfElse):
        # Differentiate both branches (ctx threaded, not_ctx results unioned).
        if_body, ctx, not_ctx1, _ = fwd_deriv_transform(
            expr.if_body, ctx, not_ctx, deps)
        else_body, ctx, not_ctx2, _ = fwd_deriv_transform(
            expr.else_body, ctx, not_ctx, deps)
        not_ctx = not_ctx1 | not_ctx2

        # For each primitive boolean in the condition, add
        # d(right - left) * jump * Delta(right - left), where jump is the node
        # with that boolean forced true minus the node with it forced false.
        deltas = Const(0)
        for boolean in primitive_booleans_in(expr.cond, not_ctx, deps):
            jump = substitute(expr, boolean, true) - substitute(
                expr, boolean, false)
            delta_expr = boolean.right_expr - boolean.left_expr

            delta_deriv, ctx, _ignore_not_ctx, _ = fwd_deriv_transform(
                delta_expr, ctx, not_ctx, deps)
            deltas = deltas + delta_deriv * jump * Delta(delta_expr)

        expr = IfElse(expr.cond, if_body, else_body) + deltas

    elif isinstance(expr, Teg):
        # NOTE(review): ctx keys are (name, uid) tuples per the annotation, so
        # testing the TegVar itself for membership likely never matches;
        # confirm whether `(expr.dvar.name, expr.dvar.uid)` was intended.
        assert expr.dvar not in ctx, f'Names of infinitesimal "{expr.dvar}" are distinct from context "{ctx}"'
        #  In int_x f(x), the variable x is in scope for the integrand f(x)
        # NOTE(review): not_ctx holds (name, uid) tuples, so discarding the
        # bare name appears to be a no-op; confirm the intended key. This
        # mutates the caller's set in place.
        not_ctx.discard(expr.dvar.name)

        # Include derivative contribution from moving boundaries of integration
        boundary_val, new_ctx, new_not_ctx = boundary_contribution(
            expr, ctx, not_ctx, deps)
        # Hold the integration variable constant while differentiating the
        # integrand (in-place mutation of not_ctx).
        not_ctx.add((expr.dvar.name, expr.dvar.uid))

        body, ctx, not_ctx, _ = fwd_deriv_transform(expr.body, ctx, not_ctx,
                                                    deps)

        # Merge in the contexts produced by boundary_contribution.
        ctx.update(new_ctx)
        not_ctx |= new_not_ctx
        expr = Teg(expr.lower, expr.upper, body, expr.dvar) + boundary_val

    elif isinstance(expr, Tup):
        # Differentiate each element, threading ctx/not_ctx and accumulating
        # the union of everything returned.
        new_expr_list, new_ctx, new_not_ctx = [], Ctx(), set()
        for child in expr:
            child, ctx, not_ctx, _ = fwd_deriv_transform(
                child, ctx, not_ctx, deps)
            new_expr_list.append(child)
            new_ctx.update(ctx)
            new_not_ctx |= not_ctx
        ctx, not_ctx = new_ctx, new_not_ctx
        expr = Tup(*new_expr_list)

    elif isinstance(expr, LetIn):

        # Compute derivatives of each expression and bind them to the corresponding dvar.
        # Only bindings whose variable appears in the body get a derivative
        # binding; the original bindings are kept first.
        new_vars_with_derivs, new_exprs_with_derivs = list(
            expr.new_vars), list(expr.new_exprs)
        new_deps = {}
        for v, e in zip(expr.new_vars, expr.new_exprs):
            if v in expr.expr:
                # By not passing in the updated contexts,
                # we require that assignments in let expressions are independent
                de, ctx, not_ctx, _ = fwd_deriv_transform(
                    e, ctx, not_ctx, deps)
                ctx[(v.name, v.uid)] = Var(f'd{v.name}')
                new_vars_with_derivs.append(ctx[(v.name, v.uid)])
                new_exprs_with_derivs.append(de)
                # Record the variables the bound expression uses.
                new_deps[v] = extract_vars(e)

        deps = {**deps, **new_deps}
        # We want an expression in terms of f'd{var_in_let_body}'
        # This means that they are erroneously added to ctx, so we
        # remove them from ctx!
        dexpr, ctx, not_ctx, _ = fwd_deriv_transform(expr.expr, ctx, not_ctx,
                                                     deps)
        [ctx.pop((c.name, c.uid), None) for c in expr.new_vars]

        expr = LetIn(Tup(*new_vars_with_derivs), Tup(*new_exprs_with_derivs),
                     dexpr)

    elif isinstance(expr, BiMap):
        # TODO: is it possible to not repeat this code and make another recursive call instead?

        # Compute derivatives of each expression and bind them to the corresponding dvar
        new_vars_with_derivs, new_exprs_with_derivs = [], []
        for v, e in zip(expr.targets, expr.target_exprs):
            if v in expr.expr:
                # By not passing in the updated contexts, require independence of exprs in the body of the let expression
                de, ctx, not_ctx, _ = fwd_deriv_transform(
                    e, ctx, not_ctx, deps)
                ctx[(v.name, v.uid)] = Var(f'd{v.name}')
                new_vars_with_derivs.append(ctx[(v.name, v.uid)])
                new_exprs_with_derivs.append(de)

                # Unlike the LetIn branch, each target is also added to not_ctx.
                not_ctx = not_ctx | {(v.name, v.uid)}

        # We want an expression in terms of f'd{var_in_let_body}'
        # This means that they are erroneously added to ctx, so we
        # remove them from ctx!
        dexpr, ctx, not_ctx, _ = fwd_deriv_transform(expr.expr, ctx, not_ctx,
                                                     deps)
        [ctx.pop((c.name, c.uid), None) for c in expr.targets]

        # Wrap a BiMap with the differentiated body (all other fields
        # unchanged) in a LetIn that binds the target derivatives.
        expr = LetIn(
            Tup(*new_vars_with_derivs), Tup(*new_exprs_with_derivs),
            BiMap(dexpr,
                  targets=expr.targets,
                  target_exprs=expr.target_exprs,
                  sources=expr.sources,
                  source_exprs=expr.source_exprs,
                  inv_jacobian=expr.inv_jacobian,
                  target_lower_bounds=expr.target_lower_bounds,
                  target_upper_bounds=expr.target_upper_bounds))

    else:
        raise ValueError(
            f'The type of the expr "{type(expr)}" does not have a supported fwd_derivative.'
        )

    return expr, ctx, not_ctx, deps
示例#8
0
文件: polar.py 项目: PeterZhouSZ/Teg
def polar_2d_map(expr, x, y, r):
    """
    Re-parameterize ``expr`` over the (x, y) plane in polar coordinates (r, theta).

    The map is centered at x=0, y=0. Sources are recovered as
    ``x = r * sin(theta)`` and ``y = r * cos(theta)``, so theta = 0 lies on the
    positive y axis and theta = +/-PI on the negative y axis. The plane is split
    at x = 0 into two ``BiMap`` branches so each branch has a continuous theta
    range: the right half-plane (x > 0) uses theta bounds within [0, TEG_PI]
    and the left half-plane (x <= 0) uses bounds within [TEG_NEGATIVE_PI, 0].
    The seam where theta jumps between the branches is the negative y axis.

    Args:
        expr: Expression to re-express in polar coordinates.
        x, y: Source variables; their ``lb()``/``ub()`` bounds are used to
            derive conservative bounds on r and theta.
        r: Target radial variable (theta is created here).

    Returns:
        An ``IfElse`` on ``x > 0`` selecting the right- or left-half-plane
        ``BiMap``.
    """
    theta = TegVar('theta')

    # Conservative bounds on r: distance from the origin to the box center,
    # plus/minus the box's half-diagonal.
    distance_to_origin = Sqrt(
        Sqr((y.lb() + y.ub()) / 2) + Sqr((x.lb() + x.ub()) / 2))
    box_radius = Sqrt(Sqr((y.ub() - y.lb()) / 2) + Sqr((x.ub() - x.lb()) / 2))

    # Manual interval arithmetic for conservative polar bounds.
    # These are not strictly necessary: using the full angular range
    # (e.g. (0, 2pi)) still produces correct, unbiased integrals.
    # However, it has terrible sampling behaviour.

    box_upper_right = ATan2(Tup(x.ub(), y.ub()))
    box_lower_right = ATan2(Tup(x.ub(), y.lb()))
    box_upper_left = ATan2(Tup(x.lb(), y.ub()))
    box_lower_left = ATan2(Tup(x.lb(), y.lb()))

    # Right half-plane: theta is smallest at the corner nearest the positive
    # y axis and largest at the corner nearest the negative y axis; fall back
    # to 0 / TEG_PI when the box straddles the corresponding axis.
    right_theta_lower = IfElse(y.ub() > 0, IfElse(x.lb() > 0, box_upper_left,
                                                  0), box_upper_right)
    right_theta_upper = IfElse(y.lb() < 0,
                               IfElse(x.lb() > 0, box_lower_left, TEG_PI),
                               box_lower_right)

    # Left half-plane: mirror image, with theta in [TEG_NEGATIVE_PI, 0].
    left_theta_upper = IfElse(y.ub() > 0, IfElse(x.ub() < 0, box_upper_right,
                                                 0), box_upper_left)
    left_theta_lower = IfElse(
        y.lb() < 0, IfElse(x.ub() < 0, box_lower_right, TEG_NEGATIVE_PI),
        box_lower_left)

    def half_plane_map(theta_lower, theta_upper):
        # Build a fresh BiMap for one half-plane. It is called once per branch,
        # so the two branches do not share expression nodes (as in the
        # original duplicated construction).
        return BiMap(expr,
                     sources=[x, y],
                     source_exprs=[r * Sin(theta), r * Cos(theta)],
                     targets=[r, theta],
                     target_exprs=[Sqrt(Sqr(x) + Sqr(y)),
                                   ATan2(Tup(x, y))],
                     inv_jacobian=r,
                     target_lower_bounds=[
                         teg_max(distance_to_origin - box_radius, 0),
                         theta_lower
                     ],
                     target_upper_bounds=[
                         distance_to_origin + box_radius, theta_upper
                     ])

    # Preserve the original construction order: condition, then right, then left.
    in_right_half = x > 0
    right_map = half_plane_map(right_theta_lower, right_theta_upper)
    left_map = half_plane_map(left_theta_lower, left_theta_upper)
    return IfElse(in_right_half, right_map, left_map)
示例#9
0
文件: delta.py 项目: PeterZhouSZ/Teg
    def outer_fn(e, ctx):
        """Rebuild ``e`` keeping only the sub-expression that contains the expr.

        ``ctx`` carries per-child results for ``e``:

        - ``ctx['has_exprs']``: one bool per child, saying whether that child
          contains the expr.
        - ``ctx['exprs']``: the corresponding rebuilt children.
        - ``ctx['let_body']``: read additionally for ``LetIn`` nodes.

        From the indexing below, ``IfElse`` children are (cond, then, else).
        ``LetIn`` children are (body, *bindings).
        NOTE(review): confirm this ordering against the traversal that
        populates ``ctx``.

        Behaviour per node type:

        - ``Add``: keep only the expr-bearing child.
        - ``IfElse``: keep the expr-bearing branch and replace the other with
          ``Const(0)``.
        - ``Tup``: replace components without the expr with ``Const(0)``.
        - ``LetIn`` whose bindings contain the expr: the body is split over
          those bound variables via ``split_exprs``. The let is then rebuilt
          with only the bindings the new body still references.
        - All other nodes, and nodes with no expr-bearing child: returned
          unchanged.

        Sets ``ctx['has_expr']`` and ``ctx['expr']`` and returns
        ``(ctx['expr'], ctx)``.

        Raises:
            AssertionError: if the expr appears inside a ``SmoothFunc``
                (non-linear use), in an ``IfElse`` condition, or in more than
                one child.
        """
        ctx['has_expr'] = any(ctx['has_exprs'])
        # Check if we need to handle other such cases.
        assert not (ctx['has_expr'] and isinstance(e, SmoothFunc)),\
               f'expr is contained in a non-linear function {type(e)}'
        if isinstance(e, Add):
            if ctx['has_expr']:
                assert sum(
                    ctx['has_exprs']) == 1, 'More than one branch with expr'
                ctx['expr'] = ctx['exprs'][ctx['has_exprs'].index(True)]
                return ctx['expr'], ctx
            else:
                ctx['expr'] = e
                return e, ctx
        elif isinstance(e, IfElse):
            if ctx['has_expr']:
                assert sum(
                    ctx['has_exprs']) == 1, 'More than one branch with expr'
                if ctx['has_exprs'][1]:
                    # If block contains expr.
                    ctx['expr'] = IfElse(e.cond, ctx['exprs'][1], Const(0))
                elif ctx['has_exprs'][2]:
                    ctx['expr'] = IfElse(e.cond, Const(0), ctx['exprs'][2])
                else:
                    # Raise explicitly: `assert False` is stripped under -O,
                    # which would leave ctx['expr'] unset or stale.
                    raise AssertionError(
                        'condition must not contain expr. expr is not linear')

                return ctx['expr'], ctx
        elif isinstance(e, Tup):
            if ctx['has_expr']:
                assert sum(
                    ctx['has_exprs']) == 1, 'More than one branch with expr'
                ctx['expr'] = Tup(*[
                    ctx['exprs'][idx] if has_expr else Const(0)
                    for idx, has_expr in enumerate(ctx['has_exprs'])
                ])
                return ctx['expr'], ctx
        elif isinstance(e, LetIn):
            assert sum(ctx['has_exprs']) <= 1, 'More than one branch with expr'
            if any(ctx['has_exprs'][1:]):
                # Let expressions contain exprs.
                new_exprs = [
                    let_var for let_var, has_expr in zip(
                        e.new_vars, ctx['has_exprs'][1:]) if has_expr
                ]

                # Recursively split the body with the new expressions.
                s_expr = split_exprs(new_exprs, ctx['let_body'])
                let_body = (s_expr if s_expr else Const(0)) +\
                           (e.expr if ctx['has_exprs'][0] else Const(0))

                # Keep only the bindings the new body still references.
                # An explicit emptiness check replaces catching the ValueError
                # from unpacking an empty zip. Catching it also swallowed
                # unrelated ValueErrors from `in` or the LetIn constructor.
                bindings = [(var, val)
                            for var, val in zip(e.new_vars, e.new_exprs)
                            if var in let_body]
                if bindings:
                    vs, es = zip(*bindings)
                    ctx['expr'] = LetIn(vs, es, let_body)
                else:
                    # No need for a let expr.
                    ctx['expr'] = let_body

                return ctx['expr'], ctx

        ctx['expr'] = e
        return ctx['expr'], ctx