def rotate_to_source(linear: Dict[Tuple[str, int], ITeg],
                     target_vars: List[TegVar],
                     source_vars: List[TegVar]) -> List[ITeg]:
    """Express each source variable in terms of the rotated target variables.

    This is the transpose (inverse) of the rotation produced by
    ``rotate_to_target``. See Appendix A for details.

    Args:
        linear: Map from ``(name, uid)`` of every source variable to its
            coefficient in the normalized linear form.
        target_vars: Rotated variables, indexed like ``source_vars``.
        source_vars: Variables to express; their order fixes the output order.

    Returns:
        One expression per source variable, in ``source_vars`` order.

    Raises:
        ValueError: If an index is not smaller than ``len(linear)``.
    """
    dim = len(target_vars)
    coeffs = [linear[(var.name, var.uid)] for var in source_vars]
    result = []
    for row in range(dim):
        if row == 0:
            # x_0 = a_0 * y_0 - sum_{k >= 1} a_k * y_k
            result.append(sum(
                (Const(1) if k == 0 else Const(-1)) * coeffs[k] * target_vars[k]
                for k in range(dim)))
            continue
        if row >= len(linear):
            raise ValueError(
                f'Requested source coordinate index: {row} is invalid.')
        # x_row = sum_{k >= 1} (delta_{row,k} - a_row a_k / (1 + a_0)) y_k
        #         + a_row * y_0
        off_axis = sum(
            ((Const(1) if row == k else Const(0))
             - (coeffs[row] * coeffs[k]) / (1 + coeffs[0])) * target_vars[k]
            for k in range(1, dim))
        result.append(off_axis + coeffs[row] * target_vars[0])
    return result
def rotate_to_target(linear: Dict[Tuple[str, int], ITeg],
                     source_vars: List[TegVar]) -> List[ITeg]:
    """Build the expressions of the rotated (target) variables.

    Row 0 is the normalized linear form itself; the remaining rows complete
    it to an orthogonal change of variables. See Appendix A for details.

    Args:
        linear: Map from ``(name, uid)`` of every source variable to its
            coefficient in the normalized linear form.
        source_vars: Variables the targets are expressed in.

    Returns:
        One expression per target coordinate, in ``source_vars`` order.

    Raises:
        ValueError: If an index is not smaller than ``len(linear)``.
    """
    dim = len(source_vars)
    coeffs = [linear[(var.name, var.uid)] for var in source_vars]
    result = []
    for row in range(dim):
        if row == 0:
            result.append(sum(coeffs[k] * source_vars[k] for k in range(dim)))
        elif row < len(linear):
            # sum_{k >= 1} (delta_{row,k} - a_row a_k / (1 + a_0)) x_k
            tail = sum(
                ((Const(1) if row == k else Const(0))
                 - (coeffs[row] * coeffs[k]) / (1 + coeffs[0])) * source_vars[k]
                for k in range(1, dim))
            result.append(-coeffs[row] * source_vars[0] + tail)
        else:
            raise ValueError(
                f'Requested target coordinate index: {row} is out of bounds.')
    return result
def outer_fn(e, ctx):
    """Post-order callback that keeps only the sub-terms containing the search.

    Reads ``ctx['exprs']`` (already-rebuilt children), ``ctx['has_exprs']``
    (one flag per child) and, for ``LetIn`` nodes, ``ctx['let_body']``.
    Writes ``ctx['has_expr']`` and ``ctx['expr']``.

    Returns:
        ``(new_expr, ctx)``, the same order in every branch.
    """
    ctx['has_expr'] = any(ctx['has_exprs'])
    # Check if we need to handle other such cases.
    assert not (ctx['has_expr'] and isinstance(e, SmoothFunc)), \
        f'expr is contained in a non-linear function {type(e)}'

    if isinstance(e, Add):
        if ctx['has_expr']:
            # Drop the summands that do not contain the searched expression.
            ctx['expr'] = sum([
                child for child, has_expr in zip(ctx['exprs'], ctx['has_exprs'])
                if has_expr
            ])
            return ctx['expr'], ctx
        ctx['expr'] = e
        return e, ctx

    elif isinstance(e, Tup):
        if ctx['has_expr']:
            # Keep the tuple's arity; zero out entries without the expression.
            ctx['expr'] = Tup(*[
                ctx['exprs'][idx] if has_expr else Const(0)
                for idx, has_expr in enumerate(ctx['has_exprs'])
            ])
            return ctx['expr'], ctx

    elif isinstance(e, IfElse):
        if ctx['has_expr']:
            # Child slots 1 and 2 are the two branches; the condition is kept.
            ctx['expr'] = IfElse(
                e.cond,
                ctx['exprs'][1] if ctx['has_exprs'][1] else Const(0),
                ctx['exprs'][2] if ctx['has_exprs'][2] else Const(0))
            # Return (expr, ctx) like every other branch.
            return ctx['expr'], ctx

    elif isinstance(e, LetIn):
        # has_exprs[0] is the let body; has_exprs[1:] line up with the bindings.
        if any(ctx['has_exprs'][1:]):
            # Let expressions contain exprs.
            new_exprs = [
                let_var for let_var, has_expr in zip(e.new_vars, ctx['has_exprs'][1:])
                if has_expr
            ]
            # Recursively split the body with the new expressions.
            s_expr = split_exprs(new_exprs, ctx['let_body'])
            let_body = ((s_expr if s_expr else Const(0))
                        + (e.expr if ctx['has_exprs'][0] else Const(0)))
            # Keep only the bindings still referenced by the new body.
            kept = [(var, bound) for var, bound in zip(e.new_vars, e.new_exprs)
                    if var in let_body]
            if kept:
                kept_vars, kept_exprs = zip(*kept)
                ctx['expr'] = LetIn(kept_vars, kept_exprs, let_body)
            else:
                # No need for a let expr.
                ctx['expr'] = let_body
            return ctx['expr'], ctx

    ctx['expr'] = e
    return ctx['expr'], ctx
def eliminate_bimaps(expr: ITeg):
    """Remove every BiMap from ``expr``, outermost first.

    For each top-level BiMap:
      * if it contains a Delta, lift it with ``split_instance`` and replace it
        with ``reparameterize``'s result (the original occurrence becomes 0);
      * otherwise, turn it into a plain ``LetIn`` of its targets.
    Repeats until ``top_level_instance_of`` finds no BiMap.
    """
    while True:
        bimap = top_level_instance_of(expr, lambda node: isinstance(node, BiMap))
        if bimap is None:
            return expr

        inner_delta = top_level_instance_of(
            bimap, lambda node: isinstance(node, Delta))
        if inner_delta is not None:
            linear_expr = split_instance(bimap, expr)
            old_tree = substitute_instance(expr, bimap, Const(0))
            new_tree = tree_copy(reparameterize(bimap, linear_expr))
            expr = old_tree + new_tree
        else:
            as_let = LetIn(bimap.targets, bimap.target_exprs, bimap.expr)
            expr = substitute_instance(expr, bimap, as_let)
def check_single_linear_var(expr, not_ctx=None):
    """Check that ``expr`` contains a single variable with coefficient 1.

    Args:
        expr: Affine expression to inspect.
        not_ctx: Iterable of TegVars treated as variables; everything else is
            part of the constant term. ``None`` (default) means no variables.

    Returns:
        True when, after removing the constant term, exactly one variable is
        left and ``Const(1)`` is among the remaining coefficients.
    """
    # Avoid a shared mutable default argument.
    not_ctx = set() if not_ctx is None else not_ctx
    affine_list = extract_coefficients_from_affine(
        expr, {(var.name, var.uid) for var in not_ctx})
    linear_list = remove_constant_coeff(affine_list)
    return (len(linear_list) == 1  # Single variable
            and Const(1) in linear_list.values())  # with a coefficient of 1
def combine_poly_sets(poly_lists: List[Dict[Tuple[Tuple[Var, int]], ITeg]], op):
    """Combine monomial->coefficient maps under addition or multiplication.

    Args:
        poly_lists: One map per operand. Keys are tuples of
            ``(variable, multiplicity)`` pairs; values are coefficients.
        op: ``operator.mul`` (expand the product) or ``operator.add``.

    Returns:
        A new map for the combined expression.

    Raises:
        ValueError: If ``op`` is neither ``operator.mul`` nor ``operator.add``.
    """
    combined_set = {}
    if op == operator.mul:
        # Cartesian product. Produce every variable combination.
        for poly_product in product(*[poly_list.items() for poly_list in poly_lists]):
            factor_vars = [var for var, _ in poly_product]
            factor_coeffs = [coeff for _, coeff in poly_product]
            term_coeff = reduce(operator.mul, factor_coeffs)
            # Reduce combined variables to primitive variables.
            monomial = combine_variables(factor_vars)
            if monomial in combined_set:
                # Distinct products can reduce to the same monomial
                # (e.g. x*C and C*x when expanding (x + 1) * (x + 1)); their
                # coefficients must be summed, not overwritten.
                combined_set[monomial] = combined_set[monomial] + term_coeff
            else:
                combined_set[monomial] = term_coeff
    elif op == operator.add:
        for poly_list in poly_lists:
            for variable, expr in poly_list.items():
                combined_set[variable] = combined_set.get(variable, Const(0)) + expr
    else:
        raise ValueError('Operation not supported')
    return combined_set
def get_poly_term(expr_list, multiplicities):
    """Look up the coefficient of one monomial.

    Args:
        expr_list: Map from monomial keys (tuples of ``(var, multiplicity)``
            sorted by ``var.uid``) to coefficients. This is presumably the output
            of ``extract_coefficients_from_polynomial``; confirm with callers.
        multiplicities: Map from variable to its power in the monomial.

    Returns:
        The matching coefficient, or ``Const(0)`` if the monomial is absent.
        A monomial with no non-constant factor is looked up as the constant term.
    """
    powers = [(var, power) for var, power in multiplicities.items()
              if var is not CONST_VAR and power != 0]
    powers.sort(key=lambda pair: pair[0].uid)
    key = tuple(powers) or ((CONST_VAR, 1), )
    return expr_list.get(key, Const(0))
def translate(in_vars: Tuple[TegVar], translate: Tuple[ITeg]):
    """Prepare a BiMap that shifts every variable by a fixed offset.

    Each target is ``in_var + t`` and each source is ``out_var - t``. The
    Jacobian is 1 and the bounds are shifted by the same offsets.

    Args:
        in_vars: Variables to translate.
        translate: Offsets, paired positionally with ``in_vars``.

    Returns:
        ``(factory, out_vars)``. ``factory`` is a ``functools.partial`` of
        ``BiMap`` still waiting for its body expression; ``out_vars`` are the
        new ``<name>_t`` TegVars.
    """
    shifted = [TegVar(f'{var.name}_t') for var in in_vars]
    factory = partial(
        BiMap,
        targets=shifted,
        target_exprs=[var + offset for var, offset in zip(in_vars, translate)],
        sources=in_vars,
        source_exprs=[new_var - offset for new_var, offset in zip(shifted, translate)],
        inv_jacobian=Const(1),
        target_upper_bounds=[var.ub() + offset for var, offset in zip(in_vars, translate)],
        target_lower_bounds=[var.lb() + offset for var, offset in zip(in_vars, translate)],
    )
    return factory, shifted
def extract_coefficients_from_polynomial(
        expr: ITeg,
        not_ctx: Set[Tuple[str, int]]) -> Dict[Tuple[Tuple[Var, int], ...], ITeg]:
    """Canonicalize a polynomial into a map from monomials to coefficients.

    Args:
        expr: Expression built from Add/Mul over TegVars and parameters.
        not_ctx: ``(name, uid)`` of the TegVars that count as variables.

    Returns:
        Map whose keys are tuples of ``(variable, multiplicity)`` pairs. The
        constant term is keyed by ``((CONST_VAR, 1),)``. (The previous
        annotation said ``Set`` keys, but sets are unhashable and cannot be keys.)
    """
    if isinstance(expr, (Mul, Add)):
        op = operator.mul if isinstance(expr, Mul) else operator.add
        children_coeffs = [
            extract_coefficients_from_polynomial(child, not_ctx)
            for child in expr.children
        ]
        return combine_poly_sets(children_coeffs, op=op)
    if isinstance(expr, TegVar) and (expr.name, expr.uid) in not_ctx:
        return {((expr, 1), ): Const(1)}
    if is_expr_parametric(expr, not_ctx):
        return {((CONST_VAR, 1), ): expr}
    # NOTE(review): anything else contributes a zero constant; confirm that
    # non-polynomial uses of variables are meant to be dropped silently.
    return {((CONST_VAR, 1), ): Const(0)}
def derivs_for_single_outval(
        expr: ITeg,
        single_outval: Const,
        i: Optional[int] = None,
        output_list: Optional[List[Var]] = None,
        args: Optional[Dict[str, Any]] = None) -> Tuple[List[Var], ITeg]:
    """Collect the reverse-mode partial derivatives for one output value.

    Args:
        expr: Expression to differentiate.
        single_outval: Seed value for the output derivative.
        i: Unused; kept for call compatibility.
        output_list: If given, report derivatives for exactly these variables
            (matched by uid) in this order, using ``Const(0)`` when absent.
            Otherwise report every variable found, sorted by uid.
        args: Options forwarded to ``reverse_deriv_transform``
            (e.g. 'ignore_deltas'). ``None`` means no options.

    Returns:
        ``(new_vars, deriv)``. ``new_vars`` are fresh ``d<name>`` Vars.
        ``deriv`` is a ``Tup`` when there is more than one, otherwise the
        single expression.

    Raises:
        AssertionError: If there are no derivatives to report.
    """
    if args is None:
        args = {}  # reverse_deriv_transform calls args.get(...)
    partial_deriv_map = defaultdict(lambda: Const(0))
    # After deriv_transform, expr will have unbound infinitesimals
    for name_uid, e in reverse_deriv_transform(expr, single_outval, set(), {}, args):
        partial_deriv_map[name_uid] += e

    # Introduce fresh variables for each partial derivative: (uid, var, value).
    entries = [(uid, Var(name), val) for (name, uid), val in partial_deriv_map.items()]

    if output_list is not None:
        # Return requested list of outputs.
        var_map = {uid: (var, val) for uid, var, val in entries}
        pairs = [var_map.get(var.uid, (Var(f'd{var.name}'), Const(0)))
                 for var in output_list]
    else:
        # Return a list sorted in the order the variables were defined.
        entries.sort(key=lambda entry: entry[0])
        pairs = [(var, val) for _, var, val in entries]

    # Check before unpacking: zip(*[]) would otherwise raise an unhelpful ValueError.
    assert len(pairs) > 0, 'There must be variables to compute derivatives. '
    new_vars, new_vals = zip(*pairs)
    return new_vars, (Tup(*new_vals) if len(new_vars) > 1 else new_vals[0])
def outer_fn(e, ctx):
    """Post-order callback that eliminates one searched Delta and its integral.

    When ``e`` is the Delta in ``ctx['search_expr']`` (it must be in normal
    form):
      * if its argument is in ``ctx['upper_tegvars']``, it becomes ``Const(1)``
        and the argument is recorded in ``ctx['eliminate_tegs']`` with value
        ``Const(0)``;
      * otherwise it becomes ``Const(0)``.
    A Teg whose variable is recorded there becomes a ``LetIn`` binding that
    variable to the recorded value, guarded by a strict bounds check.

    Returns:
        ``(new_expr, ctx)``.
    """
    is_target_delta = isinstance(e, Delta) and (ctx['search_expr'] is e)
    if is_target_delta:
        assert is_delta_normal(e), \
            f'Delta {e} is not in normal form. Call normalize_delta() first'
        if e.expr in ctx['upper_tegvars']:
            return Const(1), {
                **ctx,
                'eliminate_tegs': {**ctx['eliminate_tegs'], e.expr: Const(0)},
            }
        return Const(0), ctx

    if isinstance(e, Teg) and e.dvar in ctx['eliminate_tegs']:
        value = ctx['eliminate_tegs'][e.dvar]
        bounds_check = (e.lower < value) & (e.upper > value)
        guarded_body = IfElse(bounds_check, e.body, Const(0))
        return LetIn([e.dvar], [value], guarded_body), ctx

    return e, ctx
def extract_coefficients_from_affine(
        expr: ITeg,
        not_ctx: Set[Union[Var, Tuple]]) -> Dict[Tuple[str, int], ITeg]:
    """Canonicalize an affine expression into a variable->coefficient map.

    Keys are ``(name, uid)`` of the TegVars in ``not_ctx``. The constant term
    is stored under ``('__const__', -1)``. Mul and Add nodes are combined with
    ``combine_affine_sets``.
    """
    for node_type, op in ((Mul, operator.mul), (Add, operator.add)):
        if isinstance(expr, node_type):
            children_coeffs = [
                extract_coefficients_from_affine(child, not_ctx)
                for child in expr.children
            ]
            return combine_affine_sets(children_coeffs, op=op)

    if isinstance(expr, TegVar) and (expr.name, expr.uid) in not_ctx:
        return {(expr.name, expr.uid): Const(1)}

    constant_key = ('__const__', -1)
    if is_expr_parametric(expr, not_ctx):
        return {constant_key: expr}
    return {constant_key: Const(0)}
def rewrite(delta, not_ctx=None):
    """Define a change of variables so that Delta(x + c) becomes Delta(y).

    Args:
        delta: Delta node whose argument is affine in exactly one TegVar from
            ``not_ctx``. Its coefficient is assumed to be 1; callers presumably
            check this beforehand (e.g. via ``check_single_linear_var``).
        not_ctx: TegVars treated as variables. ``None`` (default) means none.

    Returns:
        A BiMap from ``x`` to a fresh TegVar ``x_ = x + c`` whose body is
        ``Delta(x_)``, with unit Jacobian and bounds shifted by ``c``.

    Raises:
        AssertionError: If the affine expression does not involve exactly one
            variable.
    """
    # Avoid a shared mutable default argument.
    not_ctx = set() if not_ctx is None else not_ctx
    affine_list = extract_coefficients_from_affine(
        delta.expr, {(var.name, var.uid) for var in not_ctx})
    constant = affine_list.get(('__const__', -1), Const(0))
    only_var = [(name, uid) for name, uid in affine_list.keys() if uid != -1]
    assert len(only_var) == 1, \
        f'Only one tegvar can be included in the affine expression. {only_var}'

    var_name, var_uid = only_var[0]
    source_var = TegVar(name=var_name, uid=var_uid)
    target_var = TegVar(name=f'{var_name}_')
    return BiMap(expr=Delta(target_var),
                 sources=[source_var],
                 source_exprs=[target_var - constant],
                 targets=[target_var],
                 target_exprs=[source_var + constant],
                 inv_jacobian=Const(1),
                 target_lower_bounds=[source_var.lower_bound() + constant],
                 target_upper_bounds=[source_var.upper_bound() + constant])
def transfer_bounds_general(expr: BiMap,
                            source_lower: Dict[TegVar, ITeg],
                            source_upper: Dict[TegVar, ITeg]):
    """Implements a derivative-based pessimistic bounds computation for
    continuous monotonic maps.

    For each source TegVar, the sign of ``d expr / d tegvar`` selects which of
    its bounds minimizes and which maximizes ``expr``.

    Args:
        expr: Expression to bound, treated as monotonic in each TegVar.
        source_lower: Lower bound of each source TegVar.
        source_upper: Upper bound of each source TegVar (same keys).

    Returns:
        ``(lower, upper)``: two ``LetIn`` expressions that evaluate ``expr``
        with every TegVar bound to its minimizing / maximizing endpoint.
    """
    lb_lets = {}
    ub_lets = {}
    for tegvar in source_lower:
        deriv_expr = fwd_deriv(expr, {tegvar: Const(1)})
        # Increasing in tegvar (derivative > 0): the minimum is at the lower
        # endpoint and the maximum at the upper one, matching the convention
        # used by bounds_of and rotate_2d. The original selections were swapped.
        lb_lets[tegvar] = IfElse(deriv_expr > 0, source_lower[tegvar], source_upper[tegvar])
        ub_lets[tegvar] = IfElse(deriv_expr > 0, source_upper[tegvar], source_lower[tegvar])
    # Materialize the dict views; other LetIn call sites pass lists.
    return (LetIn(list(lb_lets.keys()), list(lb_lets.values()), expr),
            LetIn(list(ub_lets.keys()), list(ub_lets.values()), expr))
def rewrite(delta: Delta, not_ctx: Optional[Set] = None) -> ITeg:
    """Rotates an affine discontinuity so that it's axis-aligned
    (e.g. ax + by + c -> z + d).

    Args:
        delta: Delta node whose argument is affine in the TegVars of ``not_ctx``.
        not_ctx: TegVars treated as variables; every other sub-expression is a
            coefficient. ``None`` (default) means no variables.

    Returns:
        A ``LetIn`` that binds the normalization variable around a ``BiMap``.
        The BiMap goes from the source TegVars to fresh rotated TegVars
        (named ``<name>_``), and its body is a ``Delta`` in the first rotated
        variable only.
    """
    # Avoid a shared mutable default argument.
    not_ctx = set() if not_ctx is None else not_ctx

    # Canonicalize affine expression into a map {var: coeff}
    raw_affine_set = extract_coefficients_from_affine(
        delta.expr, {(var.name, var.uid) for var in not_ctx})

    # Introduce a constant term if there isn't one
    if ('__const__', -1) not in raw_affine_set:
        raw_affine_set[('__const__', -1)] = Const(0)

    # Extract source variables (in order)
    source_vars = [
        TegVar(name=name, uid=uid)
        for name, uid in var_list(remove_constant_coeff(raw_affine_set))
    ]

    # Create rotated (target) variables
    target_vars = [TegVar(name=f'{var.name}_') for var in source_vars]

    # TODO: Currently, do not handle degeneracy at -1
    # NOTE(review): the flip flag is not used here; presumably acceptable
    # because Delta is even in its argument. Confirm.
    affine_set, _flip_condition = negate_degenerate_coeffs(raw_affine_set, source_vars)
    linear_set = remove_constant_coeff(affine_set)
    normalized_set, normalization_var, normalization_expr = normalize_linear(linear_set)

    dvar = target_vars[0]
    expr_for_dvar = -constant_coeff(affine_set) * normalization_var

    source_exprs = rotate_to_source(normalized_set, target_vars, source_vars)
    target_exprs = rotate_to_target(normalized_set, source_vars)
    lower_bounds, upper_bounds = bounds_of(normalized_set, source_vars)

    return LetIn([normalization_var], [normalization_expr],
                 BiMap(expr=Delta(dvar - expr_for_dvar),
                       sources=source_vars,
                       source_exprs=source_exprs,
                       targets=target_vars,
                       target_exprs=target_exprs,
                       inv_jacobian=normalization_var,
                       target_lower_bounds=lower_bounds,
                       target_upper_bounds=upper_bounds))
def rotate_2d(x, y, theta):
    """Prepare a BiMap that rotates the (x, y) plane by ``theta``.

    Targets: ``x_ = x cos + y sin`` and ``y_ = -x sin + y cos``. The inverse
    map is the transpose, and the Jacobian is 1. Each target bound picks, per
    source, the endpoint matching the sign of that source's coefficient.

    Returns:
        ``(factory, [x_, y_])``. ``factory`` is a ``functools.partial`` of
        ``BiMap`` still waiting for its body expression.
    """
    x_rot = TegVar('x_')
    y_rot = TegVar('y_')

    forward = [x * Cos(theta) + y * Sin(theta),
               -x * Sin(theta) + y * Cos(theta)]
    inverse = [x_rot * Cos(theta) - y_rot * Sin(theta),
               x_rot * Sin(theta) + y_rot * Cos(theta)]
    unit_jacobian = Const(1)
    lower = [
        Cos(theta) * IfElse(Cos(theta) > 0, x.lb(), x.ub())
        + Sin(theta) * IfElse(Sin(theta) > 0, y.lb(), y.ub()),
        -Sin(theta) * IfElse(Sin(theta) > 0, x.ub(), x.lb())
        + Cos(theta) * IfElse(Cos(theta) > 0, y.lb(), y.ub()),
    ]
    upper = [
        Cos(theta) * IfElse(Cos(theta) > 0, x.ub(), x.lb())
        + Sin(theta) * IfElse(Sin(theta) > 0, y.ub(), y.lb()),
        -Sin(theta) * IfElse(Sin(theta) > 0, x.lb(), x.ub())
        + Cos(theta) * IfElse(Cos(theta) > 0, y.ub(), y.lb()),
    ]

    factory = partial(BiMap,
                      targets=[x_rot, y_rot],
                      target_exprs=forward,
                      sources=[x, y],
                      source_exprs=inverse,
                      inv_jacobian=unit_jacobian,
                      target_lower_bounds=lower,
                      target_upper_bounds=upper)
    return factory, [x_rot, y_rot]
def bounds_of(linear: Dict[Tuple[str, int], ITeg],
              source_vars: List[TegVar]) -> Tuple[List[ITeg], List[ITeg]]:
    """Generates the bounds of integration after rotation (i.e., it's the
    bounds transfer function).

    Each target coordinate is a linear combination of the source variables.
    Its lower (upper) bound picks each source's lower or upper bound according
    to the sign of that source's coefficient.

    Args:
        linear: Map from ``(name, uid)`` of every source variable to its
            coefficient in the normalized linear form.
        source_vars: Source variables providing ``lower_bound()`` and
            ``upper_bound()``.

    Returns:
        ``(lower_bounds, upper_bounds)``, one entry per target coordinate.
        The previous annotation ``List[ITeg]`` did not match this pair.

    Raises:
        ValueError: If an index is not smaller than ``len(linear)``.
    """
    lower_bounds, upper_bounds = [], []
    num_vars = len(source_vars)
    exprs = [linear[(s_var.name, s_var.uid)] for s_var in source_vars]

    def coeff(u, v):
        """Entry (u, v) of the rotation matrix built by rotate_to_target."""
        if v == 0:
            return -exprs[u]
        return ((Const(1) if u == v else Const(0))
                - (exprs[u] * exprs[v]) / (Const(1) + exprs[0]))

    for target_index in range(num_vars):
        if target_index == 0:
            lower = sum(exprs[i] * IfElse(exprs[i] > 0,
                                          source_vars[i].lower_bound(),
                                          source_vars[i].upper_bound())
                        for i in range(num_vars))
            upper = sum(exprs[i] * IfElse(exprs[i] > 0,
                                          source_vars[i].upper_bound(),
                                          source_vars[i].lower_bound())
                        for i in range(num_vars))
        elif target_index < len(linear):
            i = target_index
            lower = upper = Const(0)
            for j in range(num_vars):
                placeholder_lb = source_vars[j].lower_bound()
                placeholder_ub = source_vars[j].upper_bound()
                lower += coeff(i, j) * IfElse(coeff(i, j) > 0, placeholder_lb, placeholder_ub)
                upper += coeff(i, j) * IfElse(coeff(i, j) > 0, placeholder_ub, placeholder_lb)
        else:
            raise ValueError(
                f'Requested target coordinate index: {target_index} is out of bounds.')
        lower_bounds.append(lower)
        upper_bounds.append(upper)
    return lower_bounds, upper_bounds
def outer_fn(e, ctx):
    """Post-order callback that puts Delta nodes into normal form.

    * Delta: asserts that none of ``ctx['upper_depvars'] - ctx['upper_tegvars']``
      appears in it. Replaces it with ``Const(0)`` if none of the TegVars seen
      below it (``ctx['lower_tegvars']``) is in ``ctx['upper_tegvars']``.
      Otherwise, if it is not already normal, rewrites it with the first
      applicable handler in ``HANDLERS`` and re-normalizes.
    * BiMap: removes its targets from ``lower_tegvars``.
    * TegVar: adds itself to ``lower_tegvars``.

    Returns:
        ``(new_expr, ctx)``.
    """
    if isinstance(e, Delta):
        free_depvars = ctx['upper_depvars'] - ctx['upper_tegvars']
        depvars = [tegvar for tegvar in free_depvars if tegvar in e]
        assert not depvars,\
            f'Delta expression {e} is not explicitly affine: ({depvars}) '\
            f'is/are dependent on one or more of {ctx["upper_tegvars"]} '\
            f'through one-way let expressions. Use bijective maps (BiMap) instead'

        lower_tegvars = ctx['lower_tegvars']
        if not lower_tegvars or not any(k in ctx['upper_tegvars'] for k in lower_tegvars):
            return Const(0), ctx

        if not is_delta_normal(e):
            can_rewrites = [
                candidate.can_rewrite(e, set(ctx['upper_tegvars']))
                for candidate in HANDLERS
            ]
            assert any(can_rewrites), \
                f'Cannot find any handler for delta expression {e}'
            chosen = HANDLERS[can_rewrites.index(True)]
            # Normalize further if necessary
            e = normalize_deltas(chosen.rewrite(e, set(ctx['upper_tegvars'])))
        return e, ctx

    if isinstance(e, BiMap):
        return e, {**ctx, 'lower_tegvars': ctx['lower_tegvars'] - set(e.targets)}

    if isinstance(e, TegVar):
        return e, {**ctx, 'lower_tegvars': ctx['lower_tegvars'] | {e}}

    return e, ctx
def _simplify_add(simple1: ITeg, simple2: ITeg) -> ITeg:
    """Simplify ``simple1 + simple2``; both operands are already simplified."""
    # Additive identity.
    if isinstance(simple1, Const) and simple1.value == 0:
        return simple2
    if isinstance(simple2, Const) and simple2.value == 0:
        return simple1
    # Local constant folding.
    if isinstance(simple1, Const) and isinstance(simple2, Const):
        return Const(evaluate(simple1 + simple2))

    # Associative reordering: fold all constants of a small sum into one.
    if isinstance(simple1, (Add, Const)) and isinstance(simple2, (Add, Const)):
        # list(...) so a tuple of children cannot break the concatenation.
        nodes1 = [simple1] if isinstance(simple1, Const) else list(simple1.children)
        nodes2 = [simple2] if isinstance(simple2, Const) else list(simple2.children)
        all_nodes = nodes1 + nodes2
        assert 2 <= len(all_nodes) <= 4, 'Unexpected number of nodes in Add-associative tree'
        const_nodes = [node for node in all_nodes if isinstance(node, Const)]
        other_nodes = [node for node in all_nodes if not isinstance(node, Const)]
        # No const nodes -> Reordering is pointless.
        if not const_nodes:
            return simple1 + simple2
        # Compress const nodes.
        const_node = Const(evaluate(reduce(operator.add, const_nodes)))
        # Append the folded constant at the end, dropping it when it is zero.
        if const_node == Const(0):
            simplified_nodes = other_nodes
        else:
            simplified_nodes = other_nodes + [const_node]
        if not simplified_nodes:
            # Everything folded away; avoid reduce() on an empty list.
            return const_node
        # Left fold, so the constant ends up as the top-level right operand.
        return reduce(operator.add, simplified_nodes)

    # Merge lets with identical bindings.
    if isinstance(simple1, LetIn) and isinstance(simple2, LetIn):
        if (simple1.new_vars == simple2.new_vars
                and simple1.new_exprs == simple2.new_exprs):
            return LetIn(new_vars=simple1.new_vars,
                         new_exprs=simple1.new_exprs,
                         expr=simplify(simple1.expr + simple2.expr))
        return simple1 + simple2

    # Merge integrals over the same variable and bounds.
    if isinstance(simple1, Teg) and isinstance(simple2, Teg):
        if (simple1.dvar == simple2.dvar and simple1.lower == simple2.lower
                and simple1.upper == simple2.upper):
            return simplify(
                Teg(simple1.lower, simple1.upper,
                    simplify(simple1.body + simple2.body), simple1.dvar))
        return simple1 + simple2

    # Merge conditionals on the same condition.
    if isinstance(simple1, IfElse) and isinstance(simple2, IfElse):
        if simple1.cond == simple2.cond:
            return IfElse(simple1.cond,
                          simplify(simple1.if_body + simple2.if_body),
                          simplify(simple1.else_body + simple2.else_body))
        return simple1 + simple2

    # Distribution: factor a shared operand out of two products.
    if isinstance(simple1, Mul) and isinstance(simple2, Mul):
        exprLL, exprLR = simple1.children
        exprRL, exprRR = simple2.children
        if exprLL == exprRR:
            return simplify(exprLL * (simplify(exprLR + exprRL)))
        if exprLL == exprRL:
            return simplify(exprLL * (simplify(exprLR + exprRR)))
        if exprLR == exprRL:
            return simplify(exprLR * (simplify(exprLL + exprRR)))
        if exprLR == exprRR:
            return simplify(exprLR * (simplify(exprLL + exprRL)))

    return simple1 + simple2


def _simplify_mul(simple1: ITeg, simple2: ITeg) -> ITeg:
    """Simplify ``simple1 * simple2``; both operands are already simplified."""
    # 0-elimination
    if ((isinstance(simple1, Const) and simple1.value == 0)
            or (isinstance(simple2, Const) and hasattr(simple2, 'value')
                and simple2.value == 0)):
        return Const(0)

    # Multiplicative identity.
    if isinstance(simple1, Const) and simple1.value == 1.0:
        return simple2
    if isinstance(simple2, Const) and simple2.value == 1.0:
        return simple1

    # Local constant compression.
    if isinstance(simple1, Const) and isinstance(simple2, Const):
        return Const(evaluate(simple1 * simple2))

    # Associative reordering: fold all constants of a small product into one.
    if isinstance(simple1, (Mul, Const)) and isinstance(simple2, (Mul, Const)):
        nodes1 = [simple1] if isinstance(simple1, Const) else list(simple1.children)
        nodes2 = [simple2] if isinstance(simple2, Const) else list(simple2.children)
        all_nodes = nodes1 + nodes2
        assert 2 <= len(all_nodes) <= 4, 'Unexpected number of nodes in Mul-associative tree'
        const_nodes = [node for node in all_nodes if isinstance(node, Const)]
        other_nodes = [node for node in all_nodes if not isinstance(node, Const)]
        # No const nodes -> Reordering is pointless.
        if not const_nodes:
            return simple1 * simple2
        # Compress const nodes.
        const_node = Const(evaluate(reduce(operator.mul, const_nodes)))
        # Append the folded constant at the end, dropping it when it is one.
        if not (const_node == Const(1)):
            simplified_nodes = other_nodes + [const_node]
        else:
            simplified_nodes = other_nodes
        if not simplified_nodes:
            # Everything folded away; avoid reduce() on an empty list.
            return const_node
        # Left fold, so the constant ends up as the top-level right operand.
        return reduce(operator.mul, simplified_nodes)

    return simple1 * simple2


def simplify(expr: ITeg) -> ITeg:
    """Apply local algebraic simplifications to ``expr`` bottom-up.

    Rules include:
      * constant folding;
      * removal of additive zeros and multiplicative ones;
      * zero-product elimination;
      * regrouping constants in small Add/Mul chains;
      * merging sibling LetIn/Teg/IfElse nodes with matching structure;
      * factoring a shared operand out of two products;
      * inlining constant let-bindings;
      * short-circuiting boolean connectives with constant operands.
    FwdDeriv/RevDeriv nodes are replaced by their simplified ``deriv_expr``.

    Raises:
        ValueError: If ``expr`` has a node type with no simplify rule.
    """
    if isinstance(expr, Var):
        return expr

    if isinstance(expr, Add):
        expr1, expr2 = expr.children
        return _simplify_add(simplify(expr1), simplify(expr2))

    if isinstance(expr, Mul):
        expr1, expr2 = expr.children
        return _simplify_mul(simplify(expr1), simplify(expr2))

    if isinstance(expr, Invert):
        simple = simplify(expr.child)
        if isinstance(simple, Const):
            return Const(evaluate(Invert(simple)))
        return Invert(simple)

    if isinstance(expr, SmoothFunc):
        simple = simplify(expr.expr)
        if isinstance(simple, Const):
            return Const(evaluate(type(expr)(simple)))
        # Reuse the simplified child. Simplifying it a second time here made
        # nested smooth functions cost exponential time in their depth.
        return type(expr)(simple)

    if isinstance(expr, IfElse):
        cond = simplify(expr.cond)
        if_body = simplify(expr.if_body)
        else_body = simplify(expr.else_body)
        if (isinstance(if_body, Const) and isinstance(else_body, Const)
                and if_body.value == 0 and else_body.value == 0):
            return if_body
        if cond == true:
            return if_body
        if cond == false:
            return else_body
        return IfElse(cond, if_body, else_body)

    if isinstance(expr, Teg):
        body = simplify(expr.body)
        if isinstance(body, Const) and hasattr(body, 'value') and body.value == 0:
            return Const(0)
        return Teg(simplify(expr.lower), simplify(expr.upper), body, expr.dvar)

    if isinstance(expr, Tup):
        return Tup(*(simplify(child) for child in expr))

    if isinstance(expr, LetIn):
        simplified_exprs = Tup(*(simplify(e) for e in expr.new_exprs))
        child_expr = simplify(expr.expr)
        vars_list = expr.new_vars
        # Inline bindings that simplified to constants.
        for s_var, s_expr in zip(vars_list, simplified_exprs):
            if isinstance(s_expr, Const):
                child_expr = substitute(child_expr, s_var, s_expr)
        non_const_bindings = [
            (s_var, s_expr) for s_var, s_expr in zip(vars_list, simplified_exprs)
            if not isinstance(s_expr, Const)
        ]
        child_expr = simplify(child_expr)
        if not non_const_bindings or isinstance(child_expr, Const):
            return child_expr
        non_const_vars, non_const_exprs = zip(*non_const_bindings)
        return LetIn(non_const_vars, non_const_exprs, child_expr)

    if isinstance(expr, BiMap):
        simplified_target_exprs = list(simplify(e) for e in expr.target_exprs)
        simplified_source_exprs = list(simplify(e) for e in expr.source_exprs)
        simplified_ubs = list(simplify(e) for e in expr.target_upper_bounds)
        simplified_lbs = list(simplify(e) for e in expr.target_lower_bounds)
        child_expr = simplify(expr.expr)
        return BiMap(expr=child_expr,
                     targets=expr.targets,
                     target_exprs=simplified_target_exprs,
                     sources=expr.sources,
                     source_exprs=simplified_source_exprs,
                     inv_jacobian=simplify(expr.inv_jacobian),
                     target_lower_bounds=simplified_lbs,
                     target_upper_bounds=simplified_ubs)

    if isinstance(expr, Delta):
        return Delta(simplify(expr.expr))

    # Derivative wrappers are matched by class name (no direct isinstance here).
    if {'FwdDeriv', 'RevDeriv'} & {t.__name__ for t in type(expr).__mro__}:
        return simplify(expr.__getattribute__('deriv_expr'))

    if isinstance(expr, Bool):
        left_expr, right_expr = simplify(expr.left_expr), simplify(expr.right_expr)
        if isinstance(left_expr, Const) and isinstance(right_expr, Const):
            return false if evaluate(Bool(left_expr, right_expr)) == 0.0 else true
        return Bool(left_expr, right_expr)

    if isinstance(expr, And):
        left_expr, right_expr = simplify(expr.left_expr), simplify(expr.right_expr)
        if left_expr == true:
            return right_expr
        if right_expr == true:
            return left_expr
        if left_expr == false or right_expr == false:
            return false
        return And(left_expr, right_expr)

    if isinstance(expr, Or):
        left_expr, right_expr = simplify(expr.left_expr), simplify(expr.right_expr)
        if left_expr == false:
            return right_expr
        if right_expr == false:
            return left_expr
        if left_expr == true or right_expr == true:
            return true
        return Or(left_expr, right_expr)

    raise ValueError(
        f'The type of the expr "{type(expr)}" does not have a supported simplify rule'
    )
t3, t4 = Var('t3'), Var('t4') scale_map, (x_s, y_s) = scale([x, y], [t1, t2]) translate_map, (x_st, y_st) = translate([x_s, y_s], [t3, t4]) # Area of a unit circle. bindings = {t1: 1, t2: 1, t3: 0, t4: 0, t: 0.25} # Derivative of threshold only. integral = Teg( x_lb, x_ub, Teg(y_lb, y_ub, scale_map(translate_map(IfElse(x_st * y_st > t, 1, 0))), y ), x ) d_vars, dt_exprs = reverse_deriv(integral, Tup(Const(1)), output_list=[t, t1, t2, t3, t4]) integral = reduce_to_base(integral) image = render_image(integral, variables=((x_lb, x_ub), (y_lb, y_ub)), bindings=bindings, bounds=((-1, 1), (-1, 1)), res=(args.res_x, args.res_y), ) save_image(np.abs(image), filename=f'{args.testname}.png') for d_var, dt_expr in zip(d_vars, dt_exprs): image = render_image(reduce_to_base(dt_expr), variables=((x_lb, x_ub), (y_lb, y_ub)), bindings=bindings, bounds=((-1, 1), (-1, 1)),
def eliminate_deltas(expr: ITeg):
    """Remove every (normal-form) Delta from ``expr``, outermost first.

    For each top-level Delta ``Delta(v)``, the linear part of ``expr`` that
    contains it is split off with ``split_instance`` and rewritten:
      * the Delta becomes ``Const(1)`` if ``v`` is integrated by an enclosing
        Teg, and ``Const(0)`` otherwise;
      * that Teg becomes ``LetIn([v], [0], IfElse(lower < 0 < upper, body, 0))``.
    The rewritten part is added back onto the remaining tree, and the process
    repeats until ``top_level_instance_of`` finds no Delta.
    """

    def inner_fn(e, ctx):
        # Top-down: track the integration variables enclosing each node.
        is_teg = isinstance(e, Teg)
        child_ctx = {
            'is_expr': (ctx['search_expr'] is e) if is_teg else False,
            'upper_tegvars': (ctx['upper_tegvars'] | {e.dvar}) if is_teg else ctx['upper_tegvars'],
            'search_expr': ctx['search_expr'],
        }
        return e, child_ctx

    def outer_fn(e, ctx):
        # Bottom-up: replace the searched Delta, then its matching integral.
        if isinstance(e, Delta) and (ctx['search_expr'] is e):
            assert is_delta_normal(e), \
                f'Delta {e} is not in normal form. Call normalize_delta() first'
            if e.expr in ctx['upper_tegvars']:
                return Const(1), {
                    **ctx,
                    'eliminate_tegs': {**ctx['eliminate_tegs'], e.expr: Const(0)},
                }
            return Const(0), ctx
        if isinstance(e, Teg) and e.dvar in ctx['eliminate_tegs']:
            value = ctx['eliminate_tegs'][e.dvar]
            bounds_check = (e.lower < value) & (e.upper > value)
            return LetIn([e.dvar], [value], IfElse(bounds_check, e.body, Const(0))), ctx
        return e, ctx

    def context_combine(contexts, ctx):
        # Union the children's lower_tegvars and merge their eliminate_tegs.
        merged_lower = set()
        for child_ctx in contexts:
            merged_lower = merged_lower | child_ctx['lower_tegvars']
        merged_elims = {}
        for child_ctx in contexts:
            merged_elims = {**merged_elims, **child_ctx['eliminate_tegs']}
        return {
            'lower_tegvars': merged_lower,
            'upper_tegvars': ctx['upper_tegvars'],
            'eliminate_tegs': merged_elims,
            'search_expr': ctx['search_expr'],
        }

    def eliminate_delta(delta, t_expr):
        start_ctx = {'upper_tegvars': set(), 'search_expr': delta}
        return base_pass(t_expr, start_ctx, inner_fn, outer_fn, context_combine)[0]

    while True:
        top_level_delta = top_level_instance_of(expr, lambda a: isinstance(a, Delta))
        if top_level_delta is None:
            return expr
        linear_expr = split_instance(top_level_delta, expr)
        old_tree = substitute_instance(expr, top_level_delta, Const(0))
        new_tree = tree_copy(eliminate_delta(top_level_delta, linear_expr))
        expr = old_tree + new_tree
def reverse_deriv_transform(
        expr: ITeg,
        out_deriv_vals: Tuple,
        not_ctx: Set[Tuple[str, int]],
        deps: Dict[TegVar, Set[Var]],
        args: Dict[str, Any]) -> Iterable[Tuple[Tuple[str, int], ITeg]]:
    """Yield reverse-mode derivative contributions of ``expr``.

    Args:
        expr: Expression to differentiate.
        out_deriv_vals: Derivative flowing in from the parent (the adjoint).
        not_ctx: ``(name, uid)`` of variables currently bound as TegVars. This
            set is never mutated; nested scopes get fresh sets.
        deps: Map from TegVars introduced by BiMaps to the variables their
            expressions use.
        args: Options; 'ignore_deltas' and 'ignore_bounds' switch off the
            discontinuity and moving-boundary terms.

    Yields:
        ``((d<name>, uid), contribution)`` pairs. Callers sum the
        contributions that share a key.

    Raises:
        ValueError: (lazily) for node types without a derivative rule.
    """
    if isinstance(expr, TegVar):
        if (((expr.name, expr.uid) not in not_ctx)
                or {(v.name, v.uid) for v in extend_dependencies({expr}, deps)} - not_ctx):
            yield ((f'd{expr.name}', expr.uid), out_deriv_vals)

    elif isinstance(expr, (Const, Delta)):
        pass

    elif isinstance(expr, Var):
        if (expr.name, expr.uid) not in not_ctx:
            yield ((f'd{expr.name}', expr.uid), out_deriv_vals)

    elif isinstance(expr, Add):
        left, right = expr.children
        left_list = list(reverse_deriv_transform(left, Const(1), not_ctx, deps, args))
        right_list = list(reverse_deriv_transform(right, Const(1), not_ctx, deps, args))
        yield from merge(left_list, right_list, out_deriv_vals)

    elif isinstance(expr, Mul):
        left, right = expr.children
        # Product rule: seed each side with the other factor.
        left_list = list(reverse_deriv_transform(left, right, not_ctx, deps, args))
        right_list = list(reverse_deriv_transform(right, left, not_ctx, deps, args))
        yield from merge(left_list, right_list, out_deriv_vals)

    elif isinstance(expr, Invert):
        # d(1/u) = -(1/u)^2 du
        child = expr.child
        yield from reverse_deriv_transform(child, -out_deriv_vals * expr * expr,
                                           not_ctx, deps, args)

    elif isinstance(expr, SmoothFunc):
        child = expr.expr
        yield from reverse_deriv_transform(
            child, expr.rev_deriv(out_deriv_expr=out_deriv_vals), not_ctx, deps, args)

    elif isinstance(expr, IfElse):
        derivs_if = reverse_deriv_transform(expr.if_body, Const(1), not_ctx, deps, args)
        derivs_else = reverse_deriv_transform(expr.else_body, Const(1), not_ctx, deps, args)
        yield from ((name_uid, out_deriv_vals * IfElse(expr.cond, deriv_if, Const(0)))
                    for name_uid, deriv_if in derivs_if)
        yield from ((name_uid, out_deriv_vals * IfElse(expr.cond, Const(0), deriv_else))
                    for name_uid, deriv_else in derivs_else)
        if not args.get('ignore_deltas', False):
            # Each discontinuity in the condition contributes a jump * Delta term.
            for boolean in primitive_booleans_in(expr.cond, not_ctx, deps):
                jump = substitute(expr, boolean, true) - substitute(expr, boolean, false)
                delta_expr = boolean.right_expr - boolean.left_expr
                derivs_delta_expr = reverse_deriv_transform(delta_expr, Const(1), not_ctx, deps, args)
                yield from ((name_uid, out_deriv_vals * deriv_delta_expr * jump * Delta(delta_expr))
                            for name_uid, deriv_delta_expr in derivs_delta_expr)

    elif isinstance(expr, Teg):
        dvar_key = (expr.dvar.name, expr.dvar.uid)
        # The bounds are outside the integral, so dvar is not a TegVar there;
        # inside the body it is. Use fresh sets instead of discard()/add() on
        # not_ctx: that set is shared with the caller, and the old in-place add
        # leaked dvar into every enclosing scope after this integral.
        bounds_ctx = not_ctx - {dvar_key}
        body_ctx = not_ctx | {dvar_key}

        # Apply Leibniz rule directly for moving boundaries
        if not args.get('ignore_bounds', False):
            lower_derivs = reverse_deriv_transform(expr.lower, out_deriv_vals, bounds_ctx, deps, args)
            upper_derivs = reverse_deriv_transform(expr.upper, out_deriv_vals, bounds_ctx, deps, args)
            yield from ((name_uid, upper_deriv * substitute(expr.body, expr.dvar, expr.upper))
                        for name_uid, upper_deriv in upper_derivs)
            yield from ((name_uid, -lower_deriv * substitute(expr.body, expr.dvar, expr.lower))
                        for name_uid, lower_deriv in lower_derivs)

        deriv_body_traces = reverse_deriv_transform(expr.body, Const(1), body_ctx, deps, args)
        yield from ((name_uid, out_deriv_vals * Teg(expr.lower, expr.upper, deriv_body, expr.dvar))
                    for name_uid, deriv_body in deriv_body_traces)

    elif isinstance(expr, Tup):
        # NOTE(review): this yields one list of generators instead of
        # (name_uid, deriv) pairs like every other branch, so consumers that
        # unpack pairs will misbehave. reverse_deriv splits multi-output Tups
        # at the top level; confirm whether nested Tups need a real rule.
        yield [
            reverse_deriv_transform(child, out_deriv_vals, not_ctx, deps, args)
            for child in expr
        ]

    elif isinstance(expr, LetIn):
        # Include derivatives of each expression to the let body
        dnew_vars, body_derivs = set(), {}
        for var, e in zip(expr.new_vars, expr.new_exprs):
            if any(Var(name=ctx_name, uid=ctx_uid) in e for ctx_name, ctx_uid in not_ctx):
                # Add dependent variables.
                assert isinstance(var, TegVar), f'{var} is dependent on TegVar(s):'\
                    f'({[name for name, uid in not_ctx if Var(name=name, uid=uid) in e]}).'\
                    f'{var} must also be declared as a TegVar and not a Var'
                not_ctx = not_ctx | {(var.name, var.uid)}
            if var not in expr.expr:
                continue
            dname = f'd{var.name}'
            dnew_vars.add((dname, var.uid))
            body_derivs[(dname, var.uid)] = list(
                reverse_deriv_transform(e, Const(1), not_ctx, deps, args))

        # Thread through derivatives of each subexpression
        for (name, uid), dname_expr in reverse_deriv_transform(
                expr.expr, out_deriv_vals, not_ctx, deps, args):
            dvar_with_ctx = LetIn(expr.new_vars, expr.new_exprs, dname_expr)
            if (name, uid) in dnew_vars:
                # Chain rule through the let-bound variable.
                yield from ((n, d * dvar_with_ctx) for n, d in body_derivs[(name, uid)])
            else:
                yield ((name, uid), dvar_with_ctx)

    elif isinstance(expr, BiMap):
        # Include derivatives of each expression to the let body
        dnew_vars, body_derivs = set(), {}
        new_deps = {}
        for var, e in zip(expr.targets, expr.target_exprs):
            if any(Var(name=ctx_name, uid=ctx_uid) in e for ctx_name, ctx_uid in not_ctx):
                # Add dependent variables.
                assert isinstance(var, TegVar), f'{var} is dependent on TegVar(s):'\
                    f'({[name for name, uid in not_ctx if Var(name=name, uid=uid) in e]}).'\
                    f'{var} must also be declared as a TegVar and not a Var'
                not_ctx = not_ctx | {(var.name, var.uid)}
            if var in expr.expr:
                new_deps[var] = extract_vars(e)
                dname = f'd{var.name}'
                dnew_vars.add((dname, var.uid))
                body_derivs[(dname, var.uid)] = list(
                    reverse_deriv_transform(e, Const(1), not_ctx, deps, args))
        deps = {**deps, **new_deps}

        # Thread through derivatives of each subexpression
        for (name, uid), dname_expr in reverse_deriv_transform(
                expr.expr, out_deriv_vals, not_ctx, deps, args):
            dvar_with_ctx = BiMap(dname_expr,
                                  expr.targets,
                                  expr.target_exprs,
                                  expr.sources,
                                  expr.source_exprs,
                                  inv_jacobian=expr.inv_jacobian,
                                  target_lower_bounds=expr.target_lower_bounds,
                                  target_upper_bounds=expr.target_upper_bounds)
            if (name, uid) in dnew_vars:
                yield from ((n, d * dvar_with_ctx) for n, d in body_derivs[(name, uid)])
            else:
                yield ((name, uid), dvar_with_ctx)

    else:
        raise ValueError(
            f'The type of the expr "{type(expr)}" does not have a supported derivative.'
        )
def teg_smoothstep(x):
    """Build the clamped cubic smoothstep of ``x`` as a Teg expression.

    The result is 0 for x <= 0, 3x^2 - 2x^3 for 0 < x < 1, and 1 for x >= 1.
    """
    # Sub-expressions are created in the same order as a single nested call would create them.
    is_positive = x > 0
    is_below_one = x < 1
    cubic_ramp = 3 * Sqr(x) - 2 * Sqr(x) * x
    upper_clamped = IfElse(is_below_one, cubic_ramp, Const(1))
    return IfElse(is_positive, upper_clamped, Const(0))
def reverse_deriv(expr: ITeg,
                  out_deriv_vals: Tup = None,
                  output_list: Optional[List[Var]] = None,
                  args: Dict[str, Any] = None) -> ITeg:
    """Computes the reverse-mode derivative of a given expression.

    Args:
        expr: The expression to differentiate. When ``out_deriv_vals`` holds
            more than one value, ``expr`` must be iterable with one
            sub-expression per value.
        out_deriv_vals: A ``Tup`` of seed values, one per output of ``expr``.
            Defaults to ``Tup(Const(1))``.
        output_list: Optional list of variables whose derivatives should be
            returned, in this order. A requested variable with no derivative
            contribution gets a new ``Var(f'd{name}')`` bound to ``Const(0)``.
            When omitted, every variable with a contribution is returned,
            sorted by uid.
        args: Additional mappings for specifying alternative behavior such as
            'ignore_deltas' and 'ignore_bounds'.

    Returns:
        For a single seed value: a pair ``(dvars, deriv)``. ``dvars`` is a tuple
        of new derivative variables. ``deriv`` is the corresponding derivative
        expression, a ``Tup`` when there is more than one variable. For several
        seed values: a ``Tup`` built from one such pair per output.

    Raises:
        AssertionError: if no derivative variables are produced, or if the
            number of seed values does not match the number of outputs.
    """
    if out_deriv_vals is None:
        out_deriv_vals = Tup(Const(1))

    if args is None:
        args = {}

    def derivs_for_single_outval(
            expr: ITeg,
            single_outval: Const,
            i: Optional[int] = None,  # Output index; currently unused.
            output_list: Optional[List[Var]] = None,
            args: Dict[str, Any] = None) -> Tuple[List[Var], ITeg]:
        """Differentiate one output seeded with ``single_outval``."""
        partial_deriv_map = defaultdict(lambda: Const(0))

        # After deriv_transform, expr will have unbound infinitesimals.
        # Contributions to the same (name, uid) are summed.
        for name_uid, e in reverse_deriv_transform(expr, single_outval, set(), {}, args):
            partial_deriv_map[name_uid] += e

        # Introduce fresh variables for each partial derivative.
        uids = [var_uid for var_name, var_uid in partial_deriv_map.keys()]
        new_vars = [Var(var_name) for var_name, var_uid in partial_deriv_map.keys()]
        new_vals = [*partial_deriv_map.values()]

        if output_list is not None:
            # Return requested list of outputs; missing ones are zero.
            var_map = {uid: (var, val) for uid, var, val in zip(uids, new_vars, new_vals)}
            entries = [
                var_map.get(var.uid, (Var(f'd{var.name}'), Const(0)))
                for var in output_list
            ]
        else:
            # Return a list sorted in the order the variables were defined (by uid).
            sorted_list = sorted(zip(uids, new_vars, new_vals), key=lambda a: a[0])
            entries = [(var, val) for _, var, val in sorted_list]

        # Check before unzipping: unzipping an empty list would raise an
        # unrelated "not enough values to unpack" error and hide this message.
        assert len(entries) > 0, 'There must be variables to compute derivatives.'
        new_vars, new_vals = zip(*entries)
        return new_vars, (Tup(*new_vals) if len(new_vars) > 1 else new_vals[0])

    if len(out_deriv_vals) == 1:
        single_outval = out_deriv_vals.children[0]
        derivs = derivs_for_single_outval(expr, single_outval, 0, output_list=output_list, args=args)
    else:
        assert len(out_deriv_vals) == len(expr), \
            f'Expected out_deriv to have "{len(expr)}" values, but got "{len(out_deriv_vals)}" values.'
        derivs = (
            derivs_for_single_outval(e, single_outval, i, output_list=output_list, args=args)
            for i, (e, single_outval) in enumerate(zip(expr, out_deriv_vals))
        )
        derivs = Tup(*derivs)
    return derivs
from teg.lang.extended import (BiMap)
from teg.math import (Sin, Cos, Sqrt, ATan2, Sqr)
import numpy as np


def teg_max(a, b):
    """Return a Teg expression selecting ``a`` when ``a > b``, otherwise ``b``."""
    return IfElse(a > b, a, b)


def teg_min(a, b):
    """Return a Teg expression selecting ``b`` when ``a > b``, otherwise ``a``."""
    return IfElse(a > b, b, a)


# Angle constants wrapped as Teg constants.
TEG_NEGATIVE_PI = Const(-np.pi)
TEG_PI = Const(np.pi)
TEG_2_PI = Const(2 * np.pi)


def polar_2d_map(expr, x, y, r):
    """Create a polar 2D map centred at x=0, y=0.

    The negative y axis corresponds to angle 0 (and 2*pi).

    NOTE(review): as visible here, the body only creates ``theta`` and computes
    ``distance_to_origin`` / ``box_radius``, and has no return statement, so it
    returns None. ``expr`` and ``r`` are unused. Confirm the rest of the body
    has not been lost.
    """
    theta = TegVar('theta')

    # Distance from the origin to the centre of the (x, y) bounding box.
    distance_to_origin = Sqrt(
        Sqr((y.lb() + y.ub()) / 2) + Sqr((x.lb() + x.ub()) / 2))
    # Half-diagonal of the (x, y) bounding box.
    box_radius = Sqrt(Sqr((y.ub() - y.lb()) / 2) + Sqr((x.ub() - x.lb()) / 2))

    # Manual interval arithmetic for conservative polar bounds.
def constant_coeff(affine: Dict[Tuple[str, int], ITeg]):
    """Return the constant term of ``affine``, or ``Const(0)`` when it has none.

    The constant term is stored under the sentinel key ``('__const__', -1)``.
    A membership test is used deliberately, so a defaultdict argument is not
    populated as a side effect.
    """
    const_key = ('__const__', -1)
    if const_key not in affine:
        return Const(0)
    return affine[const_key]
def coeff(u, v):
    """Return the rotation coefficient at row ``u``, column ``v`` built from ``exprs``.

    Column 0 gives ``-exprs[u]``. Any other column gives
    ``delta(u, v) - exprs[u] * exprs[v] / (1 + exprs[0])``, where delta is the
    Kronecker delta built from ``Const`` values.

    NOTE(review): for u >= 1 this matches the per-target rows built in
    rotate_to_target. Row u == 0 there uses ``+exprs[v]`` instead, so this
    helper is presumably only called with u >= 1. Confirm at the call site.
    """
    if v == 0:
        return -exprs[u]
    kronecker = Const(1) if u == v else Const(0)
    return kronecker - (exprs[u] * exprs[v]) / (Const(1) + exprs[0])
def fwd_deriv_transform(
        expr: ITeg,
        ctx: Dict[Tuple[str, int], ITeg],
        not_ctx: Set[Tuple[str, int]],
        deps: Dict[TegVar, Set[Var]]
) -> Tuple[ITeg, Dict[Tuple[str, int], ITeg], Set[Tuple[str, int]], Dict[TegVar, Set[Var]]]:
    """Compute the source-to-source forward derivative of the given expression.

    Args:
        expr: The expression to differentiate.
        ctx: Maps a variable's ``(name, uid)`` to the expression for its
            infinitesimal. Variables absent from ``ctx`` have derivative 0.
        not_ctx: ``(name, uid)`` pairs of variables whose derivative is treated
            as 0 (e.g. integration variables while differentiating an integrand).
        deps: Maps a TegVar to the variables it depends on. It is used to
            decide whether a variable in ``not_ctx`` still carries a derivative.

    Returns:
        ``(deriv_expr, ctx, not_ctx, deps)``: the derivative expression plus the
        possibly-updated contexts.

    NOTE(review): several branches mutate ``ctx`` / ``not_ctx`` in place (Teg,
    LetIn, BiMap), so the caller's objects can be modified. Keep this in mind
    where the same dict or set is passed to sibling calls (e.g. in Mul).
    """
    if isinstance(expr, TegVar):
        # A TegVar contributes its infinitesimal if it is not excluded, or if it
        # depends on some variable that is not excluded.
        if (((expr.name, expr.uid) not in not_ctx
                or {(v.name, v.uid) for v in extend_dependencies({expr}, deps)} - not_ctx)
                and (expr.name, expr.uid) in ctx):
            expr = ctx[(expr.name, expr.uid)]
        else:
            expr = Const(0)

    elif isinstance(expr, (Const, Placeholder, Delta)):
        expr = Const(0)

    elif isinstance(expr, Var):
        if (expr.name, expr.uid) not in not_ctx and (expr.name, expr.uid) in ctx:
            expr = ctx[(expr.name, expr.uid)]
        else:
            expr = Const(0)

    elif isinstance(expr, SmoothFunc):
        # Chain rule: differentiate the argument, then let the function apply its own derivative.
        in_deriv_expr, ctx, not_ctx, deps = fwd_deriv_transform(
            expr.expr, ctx, not_ctx, deps)
        deriv_expr = expr.fwd_deriv(in_deriv_expr=in_deriv_expr)
        expr = deriv_expr

    elif isinstance(expr, Add):
        # Linearity: the derivative of a sum is the sum of derivatives.
        sum_of_derivs = Const(0)
        for child in expr.children:
            deriv_child, ctx, not_ctx, deps = fwd_deriv_transform(
                child, ctx, not_ctx, deps)
            sum_of_derivs += deriv_child
        expr = sum_of_derivs

    elif isinstance(expr, Mul):
        # NOTE: Consider n-ary multiplication.
        assert len(
            expr.children
        ) == 2, 'fwd_deriv only supports binary multiplication not n-ary.'
        expr1, expr2 = [child for child in expr.children]

        # Both factors are differentiated from the same incoming contexts;
        # dependency maps returned by the children are discarded.
        (deriv_expr1, ctx1, not_ctx1, _) = fwd_deriv_transform(expr1, ctx, not_ctx, deps)
        (deriv_expr2, ctx2, not_ctx2, _) = fwd_deriv_transform(expr2, ctx, not_ctx, deps)

        # Product rule.
        expr = expr1 * deriv_expr2 + expr2 * deriv_expr1
        ctx = {**ctx1, **ctx2}
        not_ctx = not_ctx1 | not_ctx2

    elif isinstance(expr, Invert):
        deriv_expr, ctx, not_ctx, deps = fwd_deriv_transform(
            expr.child, ctx, not_ctx, deps)
        # d(1/u) = -(1/u)^2 du, and expr here is 1/u.
        expr = -expr * expr * deriv_expr

    elif isinstance(expr, IfElse):
        if_body, ctx, not_ctx1, _ = fwd_deriv_transform(
            expr.if_body, ctx, not_ctx, deps)
        else_body, ctx, not_ctx2, _ = fwd_deriv_transform(
            expr.else_body, ctx, not_ctx, deps)
        not_ctx = not_ctx1 | not_ctx2

        # Each primitive comparison in the condition adds a Dirac delta term,
        # weighted by the jump across the discontinuity and by the derivative
        # of the comparison's argument.
        deltas = Const(0)
        for boolean in primitive_booleans_in(expr.cond, not_ctx, deps):
            jump = substitute(expr, boolean, true) - substitute(
                expr, boolean, false)
            delta_expr = boolean.right_expr - boolean.left_expr

            delta_deriv, ctx, _ignore_not_ctx, _ = fwd_deriv_transform(
                delta_expr, ctx, not_ctx, deps)
            deltas = deltas + delta_deriv * jump * Delta(delta_expr)

        expr = IfElse(expr.cond, if_body, else_body) + deltas

    elif isinstance(expr, Teg):
        # NOTE(review): ctx keys are (name, uid) tuples, so testing the TegVar
        # object itself likely never matches. Confirm this check is effective.
        assert expr.dvar not in ctx, f'Names of infinitesimal "{expr.dvar}" are distinct from context "{ctx}"'
        # In int_x f(x), the variable x is in scope for the integrand f(x)
        # NOTE(review): not_ctx holds (name, uid) tuples (see the add below),
        # so discarding the bare name appears to be a no-op. Confirm intent.
        not_ctx.discard(expr.dvar.name)

        # Include derivative contribution from moving boundaries of integration
        boundary_val, new_ctx, new_not_ctx = boundary_contribution(
            expr, ctx, not_ctx, deps)
        # The integration variable is excluded while differentiating the integrand.
        not_ctx.add((expr.dvar.name, expr.dvar.uid))

        body, ctx, not_ctx, _ = fwd_deriv_transform(expr.body, ctx, not_ctx, deps)

        ctx.update(new_ctx)
        not_ctx |= new_not_ctx
        expr = Teg(expr.lower, expr.upper, body, expr.dvar) + boundary_val

    elif isinstance(expr, Tup):
        # Differentiate element-wise, accumulating context updates from all elements.
        new_expr_list, new_ctx, new_not_ctx = [], Ctx(), set()
        for child in expr:
            child, ctx, not_ctx, _ = fwd_deriv_transform(
                child, ctx, not_ctx, deps)
            new_expr_list.append(child)
            new_ctx.update(ctx)
            new_not_ctx |= not_ctx
        ctx, not_ctx = new_ctx, new_not_ctx
        expr = Tup(*new_expr_list)

    elif isinstance(expr, LetIn):
        # Compute derivatives of each expression and bind them to the corresponding dvar
        new_vars_with_derivs, new_exprs_with_derivs = list(
            expr.new_vars), list(expr.new_exprs)
        new_deps = {}
        for v, e in zip(expr.new_vars, expr.new_exprs):
            if v in expr.expr:
                # By not passing in the updated contexts,
                # we require that assignments in let expressions are independent
                de, ctx, not_ctx, _ = fwd_deriv_transform(
                    e, ctx, not_ctx, deps)
                ctx[(v.name, v.uid)] = Var(f'd{v.name}')
                new_vars_with_derivs.append(ctx[(v.name, v.uid)])
                new_exprs_with_derivs.append(de)
                new_deps[v] = extract_vars(e)

        deps = {**deps, **new_deps}
        # We want an expression in terms of f'd{var_in_let_body}'
        # This means that they are erroniously added to ctx, so we
        # remove them from ctx!
        dexpr, ctx, not_ctx, _ = fwd_deriv_transform(expr.expr, ctx, not_ctx, deps)
        [ctx.pop((c.name, c.uid), None) for c in expr.new_vars]

        # The original bindings are kept and the derivative bindings are appended.
        expr = LetIn(Tup(*new_vars_with_derivs), Tup(*new_exprs_with_derivs), dexpr)

    elif isinstance(expr, BiMap):
        # TODO: is it possible to not repeat this code and make another recursive call instead?
        # Compute derivatives of each expression and bind them to the corresponding dvar
        new_vars_with_derivs, new_exprs_with_derivs = [], []
        for v, e in zip(expr.targets, expr.target_exprs):
            if v in expr.expr:
                # By not passing in the updated contexts, require independence of exprs in the body of the let expression
                de, ctx, not_ctx, _ = fwd_deriv_transform(
                    e, ctx, not_ctx, deps)
                ctx[(v.name, v.uid)] = Var(f'd{v.name}')
                new_vars_with_derivs.append(ctx[(v.name, v.uid)])
                new_exprs_with_derivs.append(de)
            # Target variables are marked as excluded from the derivative context.
            # NOTE(review): confirm whether this belongs inside the `if` above;
            # the two placements differ only for targets absent from expr.expr.
            not_ctx = not_ctx | {(v.name, v.uid)}

        # We want an expression in terms of f'd{var_in_let_body}'
        # This means that they are erroniously added to ctx, so we
        # remove them from ctx!
        dexpr, ctx, not_ctx, _ = fwd_deriv_transform(expr.expr, ctx, not_ctx, deps)
        [ctx.pop((c.name, c.uid), None) for c in expr.targets]

        # Derivative bindings wrap an unchanged BiMap whose body is the differentiated expression.
        expr = LetIn(
            Tup(*new_vars_with_derivs),
            Tup(*new_exprs_with_derivs),
            BiMap(dexpr,
                  targets=expr.targets,
                  target_exprs=expr.target_exprs,
                  sources=expr.sources,
                  source_exprs=expr.source_exprs,
                  inv_jacobian=expr.inv_jacobian,
                  target_lower_bounds=expr.target_lower_bounds,
                  target_upper_bounds=expr.target_upper_bounds))

    else:
        raise ValueError(
            f'The type of the expr "{type(expr)}" does not have a supported fwd_derivative.'
        )

    return expr, ctx, not_ctx, deps
def outer_fn(e, ctx):
    """Post-order rewrite step that applies the reparameterization for ``bimap``.

    It reads ``bimap`` from the enclosing scope and returns a pair ``(expr, ctx)``.

    - The targeted BiMap becomes ``IfElse(source-bounds check, body * inv_jacobian, 0)``.
      Its sources, targets, source mappings and target bounds are recorded in ``ctx``.
    - Each enclosing Teg over one of the recorded sources is removed, and a Teg
      over the next target is queued. Once every source is consumed, the body is
      wrapped in the source let-mappings, then the queued Tegs, then any lifted
      dependent let-mappings.
    - A LetIn whose variables appear in the recorded mappings is lifted out so it
      can be re-applied outside the new integrals.

    NOTE(review): ``ctx`` is mutated in place in the Teg and LetIn branches.
    """
    if isinstance(e, BiMap) and (bimap is e):
        if not all([k in ctx['upper_tegvars'] for k in e.sources]):
            # BiMap is invalid, null everything.
            print(
                f'WARNING: Attempting to map non-Teg vars {e.sources}, {ctx["upper_tegvars"]}'
            )
            return Const(0), ctx

        # Conjunction of strict lb < dvar < ub checks over all source bounds.
        # NOTE(review): reduce() has no initial value, so an empty
        # ctx['source_bounds'] raises TypeError.
        bounds_checks = reduce(
            operator.and_,
            [(lb < dvar) & (ub > dvar) for (dvar, (lb, ub)) in ctx['source_bounds'].items()])
        reparamaterized_expr = IfElse(bounds_checks, e.expr * e.inv_jacobian, Const(0))

        return (reparamaterized_expr, {
            **ctx,
            'teg_sources': list(e.sources),
            'teg_targets': list(e.targets),
            'let_mappings': {s: sexpr for s, sexpr in zip(e.sources, e.source_exprs)},
            'target_lower_bounds': {t: tlb for t, tlb in zip(e.targets, e.target_lower_bounds)},
            'target_upper_bounds': {t: tub for t, tub in zip(e.targets, e.target_upper_bounds)}
        })

    elif isinstance(e, Teg):
        if e.dvar in ctx.get('teg_sources', {}):
            ctx['teg_sources'].remove(e.dvar)
            # Targets are consumed from the end of the list, one per removed source Teg.
            target_dvar = ctx['teg_targets'].pop()
            # Placeholder names '<uid>_ub' / '<uid>_lb' resolve to the source bounds.
            placeholders = {
                **{
                    f'{svar.uid}_ub': upper
                    for svar, (lower, upper) in ctx['source_bounds'].items(
                    )
                },
                **{
                    f'{svar.uid}_lb': lower
                    for svar, (lower, upper) in ctx['source_bounds'].items(
                    )
                }
            }
            target_lower_bounds = resolve_placeholders(
                ctx['target_lower_bounds'][target_dvar], placeholders)
            target_upper_bounds = resolve_placeholders(
                ctx['target_upper_bounds'][target_dvar], placeholders)

            # Add new teg to list.
            ctx['new_tegs'] = [
                *ctx.get('new_tegs', []),
                (target_dvar, (target_lower_bounds, target_upper_bounds))
            ]

            # Remove old teg.
            e = e.body

            if len(ctx['teg_sources']) == 0:
                # Add let mappings here.
                source_vars, source_exprs = zip(
                    *list(ctx['let_mappings'].items()))
                e = LetIn(source_vars, source_exprs, e)

                # Add new tegs here.
                for (new_dvar, (new_lb, new_ub)) in ctx['new_tegs']:
                    e = Teg(new_lb, new_ub, e, new_dvar)

                # Add dependent mappings here.
                for new_vars, new_exprs in ctx.get('dependent_mappings', []):
                    e = LetIn(new_vars, new_exprs, e)

        return e, ctx

    elif isinstance(e, LetIn):
        # Only relevant while some source integrals are still pending.
        if len(ctx.get('teg_sources', {})) > 0:
            if (any([
                    new_var in map_expr
                    for new_var in e.new_vars
                    for map_vars, map_exprs in ctx.get('dependent_mappings', [])
                    for map_expr in map_exprs
            ]) or any([
                    new_var in map_expr
                    for new_var in e.new_vars
                    for map_var, map_expr in ctx.get('let_mappings', {}).items()
            ])):
                # reparametrization is dependent on this let_map. lift this map.
                ctx['dependent_mappings'] = [
                    *ctx.get('dependent_mappings', []),
                    (e.new_vars, e.new_exprs)
                ]
                return e.expr, ctx

    return e, ctx
def rewrite(delta, not_ctx=None):
    """Rewrite a Delta over a bilinear polynomial into hyperbolic-coordinate BiMaps.

    The argument ``delta.expr`` is read as ``c_xy*x*y + c_x*x + c_y*y + c_1``
    in the first two distinct non-constant variables of its polynomial form.
    The coefficients are normalized so the x*y coefficient is positive.
    When ``c_x`` or ``c_y`` is nonzero, the variables are scaled and translated
    to remove them. The zero set then becomes a hyperbola ``x*y = const``, which
    is mapped through ``x = a*t, y = a/t`` into two BiMaps: one for
    ``a = +sqrt(const)`` and one for ``a = -sqrt(const)``.

    Args:
        delta: Delta expression; only ``delta.expr`` is read.
        not_ctx: Variables (objects with ``.name``/``.uid``) forwarded as
            ``(name, uid)`` pairs to ``extract_coefficients_from_polynomial``.
            Defaults to an empty set.

    Returns:
        A LetIn binding the normalized coefficients around the transformed sum
        of the two hyperbola BiMaps.
    """
    # Avoid a shared mutable default argument.
    if not_ctx is None:
        not_ctx = set()

    # Extract polynomial coefficients.
    poly_set = extract_coefficients_from_polynomial(
        delta.expr, {(var.name, var.uid) for var in not_ctx})

    # Collect distinct non-constant variables in order of first appearance.
    # A variable can occur in several terms (e.g. in both x and x*y), so
    # deduplicate; otherwise x and y could end up being the same variable.
    unique_vars, seen = [], set()
    for term in poly_set:
        for var, _ in term:
            if var is CONST_VAR:
                continue
            key = (var.name, var.uid)
            if key not in seen:
                seen.add(key)
                unique_vars.append(var)

    x = unique_vars[0]
    y = unique_vars[1]

    c_xy = get_poly_term(poly_set, {x: 1, y: 1})
    c_x = get_poly_term(poly_set, {x: 1})
    c_y = get_poly_term(poly_set, {y: 1})
    c_1 = get_poly_term(poly_set, {})

    c_xy_var = Var(f'c_{x.name}_{y.name}')
    c_x_var = Var(f'c_{x.name}')
    c_y_var = Var(f'c_{y.name}')
    c_1_var = Var('c_1')

    # Normalize signs so the bound x*y coefficient is non-negative.
    coeff_vars = [c_xy_var, c_x_var, c_y_var, c_1_var]
    coeff_exprs = [
        teg_abs(c_xy),
        IfElse(c_xy > 0, c_x, -c_x),
        IfElse(c_xy > 0, c_y, -c_y),
        IfElse(c_xy > 0, c_1, -c_1)
    ]

    sqrt_c_xy = Sqrt(c_xy_var)
    sqrt_c_xy_var = Var(f'{x.name}_{y.name}_sqrt')

    needs_transforms = (c_x != Const(0) or c_y != Const(0))

    if needs_transforms:
        # c_xy*x*y + c_x*x + c_y*y + c_1
        #   = (x_s + c_y/sqrt(c_xy)) * (y_s + c_x/sqrt(c_xy)) + c_1 - c_x*c_y/c_xy
        # with x_s = sqrt(c_xy)*x and y_s = sqrt(c_xy)*y.
        scale_map = partial(scale, scale=[sqrt_c_xy_var, sqrt_c_xy_var])
        translate_map = partial(
            translate,
            translate=[c_y_var / sqrt_c_xy_var, c_x_var / sqrt_c_xy_var])
        scaler, (x_s, y_s) = scale_map([x, y])
        translater, (x_st, y_st) = translate_map([x_s, y_s])
        sqr_constant = (c_x_var * c_y_var) / (c_xy_var) - c_1_var
        scale_jacobian = Const(1)
    else:
        # c_xy*x*y + c_1 = 0  <=>  x*y = -c_1/c_xy
        x_st, y_st = x, y
        sqr_constant = -c_1_var / c_xy_var
        scale_jacobian = c_xy_var

    # If threshold is negative, the hyperbola is in the second and fourth quadrants.
    # Inverting either one of x or y automatically handles this.
    conditional_inverter, (x_st, ) = scale(
        [x_st], scale=[IfElse(sqr_constant > 0, 1, -1)])
    adjusted_sqr_constant = teg_abs(sqr_constant)
    constant = Sqrt(adjusted_sqr_constant)

    # Hyperbolic transform: x = hyp_a * hyp_t, y = hyp_a / hyp_t.
    hyp_a, hyp_t = TegVar('hyp_a'), TegVar('hyp_t')

    # Build bounds transfer expressions.
    pos_a_lb = teg_cases([Sqrt(x_st.lb() * y_st.lb()), Const(0)],
                         [(x_st.lb() > 0) & (y_st.lb() > 0)])
    pos_a_ub = teg_cases([Sqrt(x_st.ub() * y_st.ub()), Const(0)],
                         [(x_st.ub() > 0) & (y_st.ub() > 0)])
    neg_a_lb = teg_cases([-Sqrt(x_st.lb() * y_st.lb()), Const(0)],
                         [(x_st.lb() < 0) & (y_st.lb() < 0)])
    neg_a_ub = teg_cases([-Sqrt(x_st.ub() * y_st.ub()), Const(0)],
                         [(x_st.ub() < 0) & (y_st.ub() < 0)])

    pos_t_lb = teg_max(
        teg_cases([
            teg_max(x_st.lb() / hyp_a, hyp_a / y_st.ub()),
            hyp_a / y_st.ub(),
            MIN_T
        ], [(y_st.ub() > 0) & (x_st.lb() > 0), y_st.ub() > 0]), MIN_T)
    pos_t_ub = teg_min(
        teg_cases([
            teg_min(x_st.ub() / hyp_a, hyp_a / y_st.lb()),
            x_st.ub() / hyp_a,
            MAX_T
        ], [(x_st.ub() > 0) & (y_st.lb() > 0), x_st.ub() > 0]), MAX_T)

    neg_t_lb = teg_max(
        teg_cases([
            teg_max(x_st.ub() / hyp_a, hyp_a / y_st.lb()),
            hyp_a / y_st.lb(),
            MIN_T
        ], [(y_st.lb() < 0) & (x_st.ub() < 0), y_st.lb() < 0]), MIN_T)
    neg_t_ub = teg_min(
        teg_cases([
            teg_min(x_st.lb() / hyp_a, hyp_a / y_st.ub()),
            x_st.lb() / hyp_a,
            MAX_T
        ], [(x_st.lb() < 0) & (y_st.ub() < 0), x_st.lb() < 0]), MAX_T)

    # Branch of the hyperbola with hyp_a = +constant.
    pos_curve = BiMap(Delta(hyp_a - constant),
                      sources=[x_st, y_st],
                      source_exprs=[hyp_a * hyp_t, hyp_a / hyp_t],
                      targets=[hyp_a, hyp_t],
                      target_exprs=[Sqrt(x_st * y_st), Sqrt(x_st / y_st)],
                      inv_jacobian=(hyp_a / hyp_t) * (1 / (constant * scale_jacobian)),
                      target_lower_bounds=[pos_a_lb, pos_t_lb],
                      target_upper_bounds=[pos_a_ub, pos_t_ub])

    # Branch of the hyperbola with hyp_a = -constant.
    neg_curve = BiMap(Delta(hyp_a + constant),
                      sources=[x_st, y_st],
                      source_exprs=[hyp_a * hyp_t, hyp_a / hyp_t],
                      targets=[hyp_a, hyp_t],
                      target_exprs=[-Sqrt(x_st * y_st), Sqrt(x_st / y_st)],
                      inv_jacobian=(-1 * hyp_a / hyp_t) * (1 / (constant * scale_jacobian)),
                      target_lower_bounds=[neg_a_lb, neg_t_lb],
                      target_upper_bounds=[neg_a_ub, neg_t_ub])

    if needs_transforms:
        return LetIn(
            coeff_vars, coeff_exprs,
            LetIn([sqrt_c_xy_var], [sqrt_c_xy],
                  scaler(
                      translater(
                          conditional_inverter(pos_curve + neg_curve)))))
    else:
        return LetIn(coeff_vars, coeff_exprs,
                     conditional_inverter(pos_curve + neg_curve))