def _execute_trivial(jaxpr, device: Optional[Device], consts, avals, handlers,
                     kept_var_idx, *args):
  """Evaluate a jaxpr without compiling it by forwarding its bound values.

  Each output variable of ``jaxpr`` must be a literal, one of its input
  variables, one of its constant variables, or ``core.unitvar``. Any other
  variable is absent from the environment, and the lookup raises ``KeyError``.

  NOTE(review): SOURCE also contains a later ``_execute_trivial`` with an
  effects-aware signature. If both live in the same module, the later one
  shadows this definition; confirm which one is intended.

  Args:
    jaxpr: jaxpr whose ``invars``, ``constvars`` and ``outvars`` are read.
    device: target device, passed to ``_copy_device_array_to_device`` and
      ``device_put``.
    consts: values bound positionally to ``jaxpr.constvars``.
    avals: not used in this body.
    handlers: one result handler per output, called as ``h(*buffers)``.
    kept_var_idx: positions in ``args`` to keep. Other arguments are dropped
      before binding to ``jaxpr.invars``.
    *args: the caller's arguments before pruning.

  Returns:
    A list with one entry per ``jaxpr.outvars``. Device arrays are copied to
    ``device``; any other value goes through ``device_put`` and is wrapped by
    its matching handler.
  """
  env = {core.unitvar: core.unit}
  pruned_args = [x for i, x in enumerate(args) if i in kept_var_idx]
  consts = list(consts)
  # Bind with explicit loops rather than ``map(env.setdefault, ...)``. That
  # form only populates ``env`` when ``map`` is an eager replacement; the
  # builtin ``map`` is lazy and would bind nothing. The length checks make a
  # mismatch fail loudly instead of being silently truncated by ``zip``.
  assert len(pruned_args) == len(jaxpr.invars), (
      f"length mismatch: {len(pruned_args)} args vs "
      f"{len(jaxpr.invars)} invars")
  assert len(consts) == len(jaxpr.constvars), (
      f"length mismatch: {len(consts)} consts vs "
      f"{len(jaxpr.constvars)} constvars")
  # ``setdefault`` keeps the first binding for a variable, so constants
  # never overwrite an input binding (nor the pre-seeded unit).
  for var, val in zip(jaxpr.invars, pruned_args):
    env.setdefault(var, val)
  for var, val in zip(jaxpr.constvars, consts):
    env.setdefault(var, val)
  outs = [xla.canonicalize_dtype(v.val) if type(v) is core.Literal else env[v]
          for v in jaxpr.outvars]
  return [_copy_device_array_to_device(x, device)
          if device_array.type_is_device_array(x)
          else h(*device_put(x, device))
          for h, x in zip(handlers, outs)]
def _device_put_impl(x, device: Optional[Device] = None):
  """Put ``x`` on ``device`` and return the wrapped result.

  If ``x`` is already a device array, it is handed to
  ``_copy_device_array_to_device``. Otherwise ``x`` is abstractified, a result
  handler is built for that abstract value, and the handler is applied to the
  buffers returned by ``device_put``.

  Raises:
    TypeError: if ``xla.abstractify`` rejects ``x``. The original error is
      chained as the cause.
  """
  if not device_array.type_is_device_array(x):
    try:
      aval = xla.abstractify(x)
    except TypeError as err:
      raise TypeError(
          f"Argument '{x}' of type {type(x)} is not a valid JAX type") from err
    # Build the handler before transferring, matching the original call order.
    handler = aval_to_result_handler(device, aval)
    return handler(*device_put(x, device))
  return _copy_device_array_to_device(x, device)
def _execute_trivial(jaxpr, device: Optional[Device], consts, avals, handlers,
                     has_unordered_effects: bool,
                     ordered_effects: List[core.Effect], kept_var_idx, *args):
  """Evaluate a jaxpr without compiling it by forwarding its bound values.

  Each output variable of ``jaxpr`` must be a literal, one of its input
  variables, or one of its constant variables. Any other variable is absent
  from the environment, and the lookup raises ``KeyError``.

  Args:
    jaxpr: jaxpr whose ``invars``, ``constvars`` and ``outvars`` are read.
    device: target device, passed to ``_copy_device_array_to_device`` and
      ``device_put``.
    consts: values bound positionally to ``jaxpr.constvars``.
    avals: not used in this body.
    handlers: one result handler per output, called as
      ``h(None, *buffers)``.
    has_unordered_effects: not used in this body.
    ordered_effects: not used in this body.
    kept_var_idx: positions in ``args`` to keep. Other arguments are dropped
      before binding to ``jaxpr.invars``.
    *args: the caller's arguments before pruning.

  Returns:
    A list with one entry per ``jaxpr.outvars``. Device arrays are copied to
    ``device``; any other value goes through ``device_put`` and is wrapped by
    its matching handler.
  """
  env: Dict[core.Var, Any] = {}
  pruned_args = [x for i, x in enumerate(args) if i in kept_var_idx]
  consts = list(consts)
  # Bind with explicit loops rather than ``map(env.setdefault, ...)``. That
  # form only populates ``env`` when ``map`` is an eager replacement; the
  # builtin ``map`` is lazy and would bind nothing. The length checks make a
  # mismatch fail loudly instead of being silently truncated by ``zip``.
  assert len(pruned_args) == len(jaxpr.invars), (
      f"length mismatch: {len(pruned_args)} args vs "
      f"{len(jaxpr.invars)} invars")
  assert len(consts) == len(jaxpr.constvars), (
      f"length mismatch: {len(consts)} consts vs "
      f"{len(jaxpr.constvars)} constvars")
  # ``setdefault`` keeps the first binding for a variable, so constants
  # never overwrite an input binding.
  for var, val in zip(jaxpr.invars, pruned_args):
    env.setdefault(var, val)
  for var, val in zip(jaxpr.constvars, consts):
    env.setdefault(var, val)
  outs = [
      xla.canonicalize_dtype(v.val) if type(v) is core.Literal else env[v]
      for v in jaxpr.outvars
  ]
  return [
      _copy_device_array_to_device(x, device)
      if device_array.type_is_device_array(x)
      else h(None, *device_put(x, device))
      for h, x in zip(handlers, outs)
  ]