def build_output_vals(self, scope, carried_state_names, carried_tree,
                      init_vals, body_typed_jaxpr, body_const_vals):
    """Bind a `cond_p` whose true branch is the traced loop body.

    The false branch is synthesized here as an identity function over the
    carried state. When `self.pred` selects it, the carried values are
    returned unchanged.

    Args:
      scope: not used by this builder; kept for the shared signature.
      carried_state_names: not used by this builder; kept for the shared
        signature.
      carried_tree: pytree definition describing the carried state.
      init_vals: flat carried values on entry.
      body_typed_jaxpr: traced body, used as the true branch.
      body_const_vals: constants closed over by `body_typed_jaxpr`.

    Returns:
      Whatever `lax_control_flow.cond_p.bind` returns.
    """
    # Trace an identity function over the carried state to act as the
    # pass-through (false) branch.
    carried_avals = tuple(safe_map(_BodyTracer.abstractify, init_vals))
    identity_jaxpr, identity_consts, _ = (
        lax_control_flow._initial_style_jaxpr(
            lambda *carried: carried, carried_tree, carried_avals))

    # Operands as assembled here: the predicate, then the true branch's
    # constants and carried values, then the false branch's constants and
    # carried values.
    operands = [self.pred]
    operands.extend(body_const_vals)
    operands.extend(init_vals)
    operands.extend(identity_consts)
    operands.extend(init_vals)
    return lax_control_flow.cond_p.bind(
        *operands,
        true_jaxpr=body_typed_jaxpr,
        false_jaxpr=identity_jaxpr)
def build_output_vals(self, scope, carried_state_names, carried_tree,
                      init_vals, body_closed_jaxpr, body_const_vals):
    """Bind an indexed `cond_p` that either runs the body or passes through.

    Branch 0 is a synthesized pass-through that returns the carried state
    unchanged. Branch 1 is `body_closed_jaxpr`. `self.index` is passed as
    the first operand of `cond_p`, ahead of the shared branch operands.

    Args:
      scope: not used by this builder; kept for the shared signature.
      carried_state_names: not used by this builder; kept for the shared
        signature.
      carried_tree: pytree definition describing the carried state.
      init_vals: flat carried values on entry.
      body_closed_jaxpr: traced body, used as branch 1.
      body_const_vals: constants closed over by `body_closed_jaxpr`.

    Returns:
      Whatever `lax_control_flow.cond_p.bind` returns.
    """
    # Both branches receive the same operand list: the body's constants
    # followed by the carried values. The pass-through is therefore traced
    # over that same structure.
    carried_state = tree_util.tree_unflatten(carried_tree, init_vals)
    flat_operands, operands_tree = tree_util.tree_flatten(
        (body_const_vals, carried_state))
    operand_avals = tuple(safe_map(_BodyTracer.abstractify, flat_operands))

    def pass_through(*traced):
        # Assuming _initial_style_jaxpr rebuilds arguments per
        # `operands_tree`, traced == (body consts, carried state); keep only
        # the carried state.
        return traced[1]

    pass_through_jaxpr, leftover_consts, _ = (
        lax_control_flow._initial_style_jaxpr(
            pass_through, operands_tree, operand_avals))
    # The pass-through closes over nothing, so it must add no constants.
    assert len(leftover_consts) == 0

    bind_args = [*body_const_vals, *init_vals]
    return lax_control_flow.cond_p.bind(
        self.index,
        *bind_args,
        branches=(pass_through_jaxpr, body_closed_jaxpr),
        linear=tuple(False for _ in bind_args))
def build_output_vals(self, scope, carried_state_names, carried_tree,
                      init_vals, body_typed_jaxpr, body_const_vals):
    """Bind a `while_p` driven by `self.cond_func` and the traced body.

    `self.cond_func` takes no arguments, but `lax.while` needs a condition
    over the carried values. A wrapper is therefore traced: it writes those
    values into `scope._mutable_state` before calling `self.cond_func`.

    Args:
      scope: object whose `_mutable_state` holds the carried fields.
      carried_state_names: names of the carried fields, in the same order as
        the carried values.
      carried_tree: pytree definition describing the carried state.
      init_vals: flat carried values on entry.
      body_typed_jaxpr: traced loop body.
      body_const_vals: constants closed over by `body_typed_jaxpr`.

    Returns:
      Whatever `lax_control_flow.while_p.bind` returns.

    Raises:
      ValueError: if the condition rebinds any carried field of the scope.
      TypeError: if the condition does not return a single boolean scalar.
    """

    def cond_over_carried(*carried):
        assert len(carried) == len(carried_state_names)
        # _initial_style_jaxpr creates its own tracers for the carried
        # state; install them so cond_func sees them.
        for field_name, tracer in zip(carried_state_names, carried):
            scope._mutable_state[field_name] = tracer
        outcome = self.cond_func()
        # The condition may not rebind any carried field. This is an
        # identity check against the objects installed above.
        for field_name, tracer in zip(carried_state_names, carried):
            if scope._mutable_state[field_name] is not tracer:
                raise ValueError(
                    "Conditional function modifies scope.{} field.".format(
                        field_name))
        # NOTE(review): the traced values are left in scope._mutable_state
        # after tracing; presumably the caller overwrites them -- confirm.
        return outcome

    carried_avals = tuple(safe_map(_BodyTracer.abstractify, init_vals))
    cond_jaxpr, cond_consts, cond_out_tree = (
        lax_control_flow._initial_style_jaxpr(
            cond_over_carried, carried_tree, carried_avals))

    # TODO: share these checks with lax_control_flow.while
    if not tree_util.treedef_is_leaf(cond_out_tree):
        raise TypeError(
            "cond_fun must return a boolean scalar, but got pytree {}.".format(
                cond_out_tree))
    expected_avals = [abstract_arrays.ShapedArray((), onp.bool_)]
    if cond_jaxpr.out_avals != expected_avals:
        raise TypeError(
            "cond_fun must return a boolean scalar, but got output type(s) {}."
            .format(cond_jaxpr.out_avals))

    # Operands as assembled here: condition constants, body constants, then
    # the carried values.
    operands = [*cond_consts, *body_const_vals, *init_vals]
    return lax_control_flow.while_p.bind(
        *operands,
        cond_nconsts=len(cond_consts),
        cond_jaxpr=cond_jaxpr,
        body_nconsts=len(body_const_vals),
        body_jaxpr=body_typed_jaxpr)