コード例 #1
0
ファイル: dispatch.py プロジェクト: jbampton/jax
def _execute_trivial(jaxpr, device: Optional[Device], consts, avals, handlers,
                     kept_var_idx, *args):
  """Produce the outputs of ``jaxpr`` without evaluating any of its equations.

  Input variables are bound to the kept positional ``args`` and constant
  variables to ``consts``; each output is then either a dtype-canonicalized
  literal or the value bound to that variable. ``jaxpr.eqns`` is never
  consulted, so this is presumably only used for jaxprs whose outputs are
  direct inputs/constants/literals -- TODO confirm at the call site.

  Args:
    jaxpr: object exposing ``invars``, ``constvars`` and ``outvars``.
    device: target device forwarded to the copy / ``device_put`` path.
    consts: values bound positionally to ``jaxpr.constvars``.
    avals: unused in this function.
    handlers: one callable per output; called with the unpacked result of
      ``device_put`` for outputs that are not already device arrays.
    kept_var_idx: container of positional indices into ``args`` to keep
      (tested with ``in``).
    *args: positional arguments; only those whose index is in
      ``kept_var_idx`` are bound, in order, to ``jaxpr.invars``.

  Returns:
    A list with one entry per ``(handler, output)`` pair (``zip`` stops at
    the shorter of the two).
  """
  env = {core.unitvar: core.unit}
  kept_args = (arg for pos, arg in enumerate(args) if pos in kept_var_idx)
  # NOTE(review): these two `map` calls are evaluated only for their side
  # effect of filling `env` (setdefault => the first binding of a var wins).
  # That relies on the module-level `map` being eager (e.g. a safe_map-style
  # helper); the builtin lazy `map` would make them no-ops -- confirm.
  map(env.setdefault, jaxpr.invars, kept_args)
  map(env.setdefault, jaxpr.constvars, consts)

  def _lookup(var):
    # Literals carry their value inline; everything else was bound above.
    if type(var) is core.Literal:
      return xla.canonicalize_dtype(var.val)
    return env[var]

  results = [_lookup(var) for var in jaxpr.outvars]

  def _place(handler, value):
    # Existing device arrays are copied to `device`; other values are
    # transferred with device_put and wrapped by their handler.
    if device_array.type_is_device_array(value):
      return _copy_device_array_to_device(value, device)
    return handler(*device_put(value, device))

  return [_place(handler, value) for handler, value in zip(handlers, results)]
コード例 #2
0
ファイル: dispatch.py プロジェクト: romanngg/jax
def _execute_trivial(jaxpr, device: Optional[Device], consts, avals, handlers,
                     has_unordered_effects: bool,
                     ordered_effects: List[core.Effect], kept_var_idx, *args):
    """Produce the outputs of ``jaxpr`` without evaluating any of its equations.

    Input variables are bound to the kept positional ``args`` and constant
    variables to ``consts``; each output is then either a dtype-canonicalized
    literal or the value bound to that variable. ``jaxpr.eqns`` is never read,
    so callers presumably only route equation-free jaxprs here -- TODO confirm.

    Args:
        jaxpr: object exposing ``invars``, ``constvars`` and ``outvars``.
        device: target device forwarded to the copy / ``device_put`` path.
        consts: values bound positionally to ``jaxpr.constvars``.
        avals: unused in this function.
        handlers: one callable per output; called as
            ``handler(None, *device_put(value, device))`` for outputs that are
            not already device arrays (the meaning of the leading ``None`` is
            defined by the handler -- confirm).
        has_unordered_effects: unused in this function.
        ordered_effects: unused in this function.
        kept_var_idx: container of positional indices into ``args`` to keep
            (tested with ``in``).
        *args: positional arguments; only those whose index is in
            ``kept_var_idx`` are bound, in order, to ``jaxpr.invars``.

    Returns:
        A list with one entry per ``(handler, output)`` pair (``zip`` stops at
        the shorter of the two).
    """
    env: Dict[core.Var, Any] = {}
    kept_args = (arg for pos, arg in enumerate(args) if pos in kept_var_idx)
    # NOTE(review): these `map` calls run only for their side effect on `env`
    # (setdefault => the first binding wins). They rely on the module-level
    # `map` being eager (e.g. a safe_map-style helper); the builtin lazy `map`
    # would turn them into no-ops -- confirm.
    map(env.setdefault, jaxpr.invars, kept_args)
    map(env.setdefault, jaxpr.constvars, consts)

    def _read(var):
        # Literals carry their value inline; everything else was bound above.
        if type(var) is core.Literal:
            return xla.canonicalize_dtype(var.val)
        return env[var]

    def _materialize(handler, value):
        # Existing device arrays are copied to `device`; other values are
        # transferred with device_put and wrapped by their handler.
        if device_array.type_is_device_array(value):
            return _copy_device_array_to_device(value, device)
        return handler(None, *device_put(value, device))

    values = [_read(var) for var in jaxpr.outvars]
    return [_materialize(handler, value)
            for handler, value in zip(handlers, values)]
コード例 #3
0
ファイル: dispatch.py プロジェクト: jbampton/jax
def device_put(x, device: Optional[Device] = None) -> Tuple[Any]:
  """Canonicalize ``x``'s dtype and transfer it via its registered handler.

  The handler is selected by the exact ``type(x)`` after dtype
  canonicalization (a plain dict lookup, so subclasses of a registered type
  are not matched).

  Args:
    x: value to transfer.
    device: optional target device, forwarded unchanged to the handler.

  Returns:
    Whatever the registered handler returns.

  Raises:
    TypeError: if no handler is registered for ``type(x)``. Exceptions raised
      by the handler itself (including ``KeyError``) propagate unchanged.
  """
  x = xla.canonicalize_dtype(x)
  # Keep the `try` limited to the registry lookup: previously the handler call
  # was inside it too, so a KeyError raised by a handler was misreported as a
  # missing handler and the real error was hidden behind a TypeError.
  try:
    handler = device_put_handlers[type(x)]
  except KeyError as err:
    raise TypeError(f"No device_put handler for type: {type(x)}") from err
  return handler(x, device)
コード例 #4
0
 def read(v):
   """Return the lowered value(s) for jaxpr atom ``v``.

   A literal is dtype-canonicalized and wrapped as a one-element list holding
   ``xb.constant(c, ...)``; any other atom is looked up in ``env``. Both
   ``env`` and ``c`` are free variables from the enclosing scope (``c`` is
   presumably the XLA builder being lowered into -- confirm).
   """
   if type(v) is not core.Literal:
     return env[v]
   return [xb.constant(c, xla.canonicalize_dtype(v.val))]
コード例 #5
0
def _bdint_canoncalize_dtype(x):
  """Return a new ``BoundedInt`` whose ``_val`` has a canonicalized dtype.

  ``x._bound`` is carried over unchanged.
  """
  # NOTE(review): "canoncalize" is a misspelling of "canonicalize"; renaming
  # would require updating every reference to this function.
  canonical_val = xla.canonicalize_dtype(x._val)
  return BoundedInt(canonical_val, x._bound)