Exemplo n.º 1
0
 def build_output_vals(self, scope, carried_state_names, carried_tree,
                       init_vals, body_typed_jaxpr, body_const_vals):
   """Binds `cond_p` with the traced body as the true branch.

   The false branch is traced from an identity function over the carried
   state, so it hands `init_vals` back unchanged.

   Args:
     scope: not used by this method.
     carried_state_names: not used by this method.
     carried_tree: pytree structure used to unflatten `init_vals` while
       tracing the identity (false) branch.
     init_vals: flat initial values of the carried state; passed as the
       operands of both branches.
     body_typed_jaxpr: the traced body, bound as `true_jaxpr`.
     body_const_vals: constants of `body_typed_jaxpr`, placed ahead of the
       true-branch operands.

   Returns:
     The result of `lax_control_flow.cond_p.bind`.
   """
   # Trace a pass-through branch that accepts the same avals as the body.
   carried_avals = tuple(safe_map(_BodyTracer.abstractify, init_vals))
   identity_jaxpr, identity_consts, _ = lax_control_flow._initial_style_jaxpr(
       lambda *carried: carried, carried_tree, carried_avals)
   # Operand layout expected by cond_p: the predicate, then (consts, state)
   # for the true branch, then (consts, state) for the false branch.
   true_operands = [*body_const_vals, *init_vals]
   false_operands = [*identity_consts, *init_vals]
   return lax_control_flow.cond_p.bind(
       self.pred, *true_operands, *false_operands,
       true_jaxpr=body_typed_jaxpr, false_jaxpr=identity_jaxpr)
Exemplo n.º 2
0
 def build_output_vals(self, scope, carried_state_names, carried_tree,
                       init_vals, body_closed_jaxpr, body_const_vals):
   """Binds an indexed `cond_p` that runs either a pass-through or the body.

   `branches` is `(pass_through, body)`, so `self.index` picks the
   pass-through at position 0 and the traced body at position 1. Both
   branches receive the same operands: `body_const_vals` followed by
   `init_vals`.

   Args:
     scope: not used by this method.
     carried_state_names: not used by this method.
     carried_tree: pytree structure used to unflatten `init_vals`.
     init_vals: flat initial values of the carried state.
     body_closed_jaxpr: the traced body, used as branch 1.
     body_const_vals: constants of `body_closed_jaxpr`; also fed (and
       ignored) by the pass-through branch so both signatures agree.

   Returns:
     The result of `lax_control_flow.cond_p.bind`.
   """
   # The pass-through must share the body's input signature, i.e.
   # (body constants, carried state). It drops the constants and returns the
   # carried state untouched.
   carried_state = tree_util.tree_unflatten(carried_tree, init_vals)
   flat_inputs, inputs_tree = tree_util.tree_flatten(
       (body_const_vals, carried_state))
   input_avals = tuple(safe_map(_BodyTracer.abstractify, flat_inputs))

   def _pass_through(*unflattened):
     return unflattened[1]

   pass_through_jaxpr, pass_through_consts, _ = (
       lax_control_flow._initial_style_jaxpr(
           _pass_through, inputs_tree, input_avals))
   # The pass-through closes over no constants of its own.
   assert len(pass_through_consts) == 0
   operands = [*body_const_vals, *init_vals]
   return lax_control_flow.cond_p.bind(
       self.index, *operands,
       branches=(pass_through_jaxpr, body_closed_jaxpr),
       linear=tuple([False] * len(operands)))
Exemplo n.º 3
0
    def build_output_vals(self, scope, carried_state_names, carried_tree,
                          init_vals, body_typed_jaxpr, body_const_vals):
        """Binds `while_p` with `self.cond_func` as the loop condition.

        `self.cond_func` takes no arguments and reads the loop state from
        `scope._mutable_state`, whereas `while_p` needs a condition over the
        carried values. The condition is therefore traced through a wrapper
        that first installs the per-field tracers into the scope.

        Args:
            scope: object whose `_mutable_state` mapping holds the carried
                fields by name.
            carried_state_names: names of the carried fields, in the same
                order as the unflattened carried state.
            carried_tree: pytree structure used to unflatten `init_vals`.
            init_vals: flat initial values of the carried state.
            body_typed_jaxpr: the traced loop body.
            body_const_vals: constants of `body_typed_jaxpr`.

        Returns:
            The result of `lax_control_flow.while_p.bind`.

        Raises:
            ValueError: if the condition rebinds a carried field in
                `scope._mutable_state`.
            TypeError: if the condition does not return a single boolean
                scalar.

        NOTE(review): the scope entries for the carried names are left bound
        to the tracers installed during condition tracing; presumably the
        caller overwrites them afterwards — confirm.
        """
        def _wrapped_cond(*carried):
            # _initial_style_jaxpr opens its own trace and passes one tracer
            # per carried field; expose them to cond_func through the scope.
            assert len(carried) == len(carried_state_names)
            for field_name, field_val in zip(carried_state_names, carried):
                scope._mutable_state[field_name] = field_val
            outcome = self.cond_func()
            # The condition must treat the scope state as read-only.
            for field_name, field_val in zip(carried_state_names, carried):
                if scope._mutable_state[field_name] is not field_val:
                    raise ValueError(
                        "Conditional function modifies scope.{} field.".format(
                            field_name))
            return outcome

        carried_avals = tuple(safe_map(_BodyTracer.abstractify, init_vals))
        cond_jaxpr, cond_consts, cond_tree = (
            lax_control_flow._initial_style_jaxpr(
                _wrapped_cond, carried_tree, carried_avals))
        # TODO: share these checks with lax_control_flow.while
        if not tree_util.treedef_is_leaf(cond_tree):
            raise TypeError(
                "cond_fun must return a boolean scalar, but got pytree {}."
                .format(cond_tree))
        bool_scalar_avals = [abstract_arrays.ShapedArray((), onp.bool_)]
        if cond_jaxpr.out_avals != bool_scalar_avals:
            raise TypeError(
                "cond_fun must return a boolean scalar, but got output type(s) {}."
                .format(cond_jaxpr.out_avals))

        # while_p operands: condition consts, body consts, then carried state.
        operands = [*cond_consts, *body_const_vals, *init_vals]
        return lax_control_flow.while_p.bind(
            *operands,
            cond_nconsts=len(cond_consts),
            cond_jaxpr=cond_jaxpr,
            body_nconsts=len(body_const_vals),
            body_jaxpr=body_typed_jaxpr)