Exemplo n.º 1
0
 def batched_cond_fun(batched_loop_carry):
   """Evaluate the loop condition for a batched carry and combine the results.

   The cond jaxpr is run through `batching.batch_transform` with the output
   batch dimension placed at axis 0. The per-element predicates are then
   OR-reduced over that axis, so the result is True when the condition holds
   for at least one batch element.
   """
   def lifted(carry, consts):
     return core.eval_jaxpr(cond_jaxpr, consts, (), carry)

   # Presumably: batch dims of (loop carry, cond consts) as fed to lifted.
   in_bdims = (init_val_bd, cond_consts_bd)
   batched_cond = batching.batch_transform(
       lu.wrap_init(lifted), size, in_bdims, 0)
   per_elem_preds = batched_cond.call_wrapped((batched_loop_carry, cond_consts))
   # Logical OR over the batch axis (0): keep looping while any element wants to.
   any_pred = lax.reduce(per_elem_preds, onp.array(False), lax.bitwise_or, [0])
   return any_pred
Exemplo n.º 2
0
 def batched_body_fun(batched_loop_carry):
   """Run one loop-body step on a batched carry, per element gated by cond.

   For each batch element, the cond jaxpr is evaluated on the current carry
   and the body jaxpr produces a candidate next carry. `_jaxtupletree_select`
   then combines the two; presumably it keeps the old carry for elements
   whose predicate is False (TODO confirm the select semantics). The result
   is batched along `init_val_bd`.
   """
   def lifted(carry, cond_args, body_args):
     keep_going = core.eval_jaxpr(cond_jaxpr, cond_args, (), carry)
     stepped = core.eval_jaxpr(body_jaxpr, body_args, (), carry)
     return _jaxtupletree_select(keep_going, stepped, carry)

   # Presumably: batch dims of (loop carry, cond consts, body consts).
   in_bdims = (init_val_bd, cond_consts_bd, body_consts_bd)
   batched_step = batching.batch_transform(
       lu.wrap_init(lifted), size, in_bdims, init_val_bd)
   step_args = (batched_loop_carry, cond_consts, body_consts)
   return batched_step.call_wrapped(step_args)