def if_stmt(cond, body, orelse, local_writes):
  """Functional form of an if statement.

  The predicate is evaluated exactly once. If it yields a `z3.BoolRef`, both
  branches are executed via `staging.execute_isolated`. Each variable in
  `local_writes` is then rebound to a `z3.If` expression that selects between
  the two branch results. Otherwise the already-evaluated predicate is
  forwarded to `py_defaults.if_stmt`.

  Args:
    cond: Callable with no arguments, predicate of conditional.
    body: Callable with no arguments, and outputs of the positive (if) branch
      as return type.
    orelse: Callable with no arguments, and outputs of the negative (else)
      branch as return type.
    local_writes: list(pyct.Variable), list of variables assigned in either
      body or orelse.

  Returns:
    None. In the z3 case the statement outputs are delivered by assigning the
    `.val` attribute of each variable in `local_writes`. In the non-z3 case
    any value returned by `py_defaults.if_stmt` is not propagated.
  """
  cond_result = cond()
  if isinstance(cond_result, z3.BoolRef):
    # Both branches must run, because z3 needs concrete values from each side
    # to build the symbolic conditional.
    body_vals, _ = staging.execute_isolated(body, local_writes)
    orelse_vals, _ = staging.execute_isolated(orelse, local_writes)
    # Results are paired positionally with `local_writes`. `zip` stops at the
    # shortest input, so this assumes both branches report one value per
    # variable — NOTE(review): confirm execute_isolated guarantees that.
    for body_result, else_result, modified_var in zip(
        body_vals, orelse_vals, local_writes):
      # Unlike e.g., TensorFlow, z3 does not do tracing on If statements.
      # Instead, it expects the results of the body and orelse branches passed
      # as values. As such, each result is the result of the deferred z3.If
      # statement.
      modified_var.val = z3.If(cond_result, body_result, else_result)
  else:
    # Wrap the already-computed predicate so `cond` is not evaluated twice
    # (it may have side effects).
    py_defaults.if_stmt(lambda: cond_result, body, orelse, local_writes)
def for_body(iterate_index, *state):
  """Runs a single iteration of the loop body.

  Args:
    iterate_index: Position in `iter_` of the element bound to `target` for
      this iteration.
    *state: Incoming values of the loop variables, paired positionally with
      `local_writes`.

  Returns:
    A list holding `iterate_index + 1` followed by the modified values that
    `staging.execute_isolated` reports after running `body`.
  """
  # Restore the loop-carried state onto the tracked variables.
  for variable, incoming in zip(local_writes, state):
    variable.val = incoming
  target.val = iter_[iterate_index]
  updated_vals, _ = staging.execute_isolated(body, local_writes)
  # The advanced index leads the returned state, ahead of the loop variables.
  next_state = [iterate_index + 1]
  return next_state + updated_vals
def if_body(*_):
  """Runs the `body` branch in isolation.

  Any positional arguments are accepted and ignored.

  Returns:
    The first element of the pair returned by `staging.execute_isolated`
    (the values of `local_writes` modified by `body`).
  """
  branch_vals, _unused = staging.execute_isolated(body, local_writes)
  return branch_vals
def while_body(*state):
  """Runs one iteration of the while-loop body.

  Args:
    *state: Incoming values of the loop variables, paired positionally with
      `local_writes`.

  Returns:
    The modified values reported by `staging.execute_isolated` for `body`.
  """
  # Load the loop-carried state into the tracked variables before executing.
  for variable, incoming in zip(local_writes, state):
    variable.val = incoming
  updated_vals, _unused = staging.execute_isolated(body, local_writes)
  return updated_vals
def while_test(*state):
  """Evaluates the while-loop condition against the given state.

  Args:
    *state: Current values of the loop variables, paired positionally with
      `local_writes`.

  Returns:
    The second element of the pair returned by `staging.execute_isolated`
    for `cond` (its return values).
  """
  for variable, incoming in zip(local_writes, state):
    variable.val = incoming
  _modified, cond_outputs = staging.execute_isolated(cond, local_writes)
  return cond_outputs
def if_orelse(*_):
  """Runs the `orelse` branch in isolation.

  Any positional arguments are accepted and ignored.

  Returns:
    The first element of the pair returned by `staging.execute_isolated`
    (the values of `local_writes` modified by `orelse`).
  """
  branch_vals, _unused = staging.execute_isolated(orelse, local_writes)
  return branch_vals
def for_body(idx, state):
  """Runs one iteration of the loop body for element `iter_[idx]`.

  Args:
    idx: Position in `iter_` of the element bound to `target`.
    state: Iterable of incoming loop-variable values, paired positionally
      with `modified_vars`.

  Returns:
    The modified values reported by `staging.execute_isolated` for `body`.
    Unlike the index-carrying variant, the index is not included.
  """
  # Restore the loop-carried state onto the tracked variables.
  for variable, incoming in zip(modified_vars, state):
    variable.val = incoming
  target.val = iter_[idx]
  updated_vals, _unused = staging.execute_isolated(body, modified_vars)
  return updated_vals