示例#1
0
def _scan_transpose(ct, consts, init, xs, forward, length, jaxpr):
  """Transpose rule for ``scan_p``: propagate cotangents back through a scan.

  Builds a transposed scan body from ``jaxpr`` and binds ``scan_p`` again with
  the direction flipped (``forward=not forward``). The result is the
  cotangents of the constants, the initial carry and the linear scanned
  inputs.

  Args:
    ct: cotangent of the scan output. It is zero-instantiated against
      ``(c_aval, bs_aval)`` and unpacked as ``(ct_c, ct_bs)``.
    consts: must be ``None`` (asserted). Presumably this marks the consts as
      a linear/undefined input being transposed; TODO confirm.
    init: must be ``None`` (asserted), with the same caveat as ``consts``.
    xs: tuple ``(a, res)``. ``a`` must be ``None`` and ``res`` must not be
      (the residuals consumed by the transposed body).
    forward: scan direction of the original scan.
    length: number of scan iterations.
    jaxpr: the original scan body, typed as in the comment block below.

  Returns:
    ``(ct_consts, ct_init, (ct_as, None))``. This mirrors the
    ``(consts, init, (a, res))`` argument structure, with ``None`` in the
    ``res`` slot.
  """
  assert consts is None and init is None
  assert type(xs) is tuple
  a, res = xs
  assert a is None and res is not None

  # Notation: d = consts, c = carry, a = linear scanned input,
  # res = residual scanned input, b = per-iteration output.
  # jaxpr :: d -> c -> (a, res) ->  (c, b)
  # jaxpr_lifted :: res -> (d, c, a) -> (c, b)
  # jaxpr_lifted_trans :: res -> (CT c, CT b) -> (CT d, CT c, CT a)
  # jaxpr_trans :: * -> (CT c, CT d) -> (CT b, res) -> ((CT c, CT d), CT a)
  # The third binder must already be a tuple so that rearrange_binders can
  # split it into (a, res) below.
  assert type(jaxpr.jaxpr.invars[2]) is tuple  # assume restructuring
  # Move the residuals to the front, so that everything after them is the
  # linear part to be transposed.
  jaxpr_lifted = rearrange_binders(
      lambda d, c, a_res: (a_res[1], (d, c, a_res[0])), jaxpr)
  jaxpr_lifted_trans = _transpose_jaxpr(jaxpr_lifted)
  # NOTE(review): presumably reshapes the lifted transpose into the
  # jaxpr_trans signature above (carrying CT d and accumulating it with an
  # add); semantics are defined elsewhere -- confirm.
  jaxpr_trans = _move_stuff_and_add_add(jaxpr_lifted_trans)

  c_aval, b_aval = jaxpr.out_aval
  d_aval, c_aval2, _ = jaxpr.in_avals
  # The carry's input and output types must agree.
  assert c_aval == c_aval2
  # Presumably the aval of b stacked over `length` iterations; TODO confirm.
  bs_aval = _promote_aval_rank(length, b_aval)
  # The cotangent of the consts starts at zero and is carried through the
  # transposed scan.
  ct_d = ad_util.zeros_like_aval(d_aval)
  ct_c, ct_bs = ad.instantiate_zeros_aval(core.AbstractTuple((c_aval, bs_aval)), ct)
  carry_ct = core.pack((ct_c, ct_d))

  # jaxpr_trans :: * -> (CT c, CT d) -> (CT b, res) -> ((CT c, CT d), CT a)
  # Sanity checks: the transposed body is well-formed, and the initial
  # carry cotangents fit its carry binder types.
  core.check_jaxpr(jaxpr_trans.jaxpr)
  unit_aval, (ct_c_aval, ct_d_aval), (ct_b_aval, _) = jaxpr_trans.in_avals
  assert core.lattice_join(ct_c_aval, core.get_aval(ct_c)) == ct_c_aval
  assert core.lattice_join(ct_d_aval, core.get_aval(ct_d)) == ct_d_aval

  # Run the transposed body in the opposite direction. CT b and res are
  # scanned over; (CT c, CT d) is the carry.
  out = scan_p.bind(
      core.unit, carry_ct, core.pack((ct_bs, res)),
      forward=not forward, length=length, jaxpr=jaxpr_trans)
  (ct_init, ct_consts), ct_as = out
  return ct_consts, ct_init, (ct_as, None)
示例#2
0
  def test_lattice_join_named_shape(self):
    """lattice_join unions disjoint named axes and rejects size conflicts."""
    def shaped(named_shape):
      # Every aval in this test shares shape/dtype/weak_type; only the
      # named shape varies.
      return core.ShapedArray((2, 3), np.float32, False, named_shape)

    i10 = shaped({'i': 10})
    # Joining an aval with itself is the identity.
    self.assertEqual(core.lattice_join(i10, i10), i10)

    # Disjoint named axes are merged into one named shape.
    j5 = shaped({'j': 5})
    merged = shaped({'i': 10, 'j': 5})
    self.assertEqual(core.lattice_join(i10, j5), merged)

    # The same axis name with two different sizes cannot be joined.
    i5 = shaped({'i': 5})
    with self.assertRaises(TypeError):
      core.lattice_join(i10, i5)
示例#3
0
文件: ad.py 项目: jbampton/jax
 def write_cotangent(prim, v, ct):
   """Accumulate cotangent ``ct`` for variable ``v`` into ``ct_env``.

   ``None`` cotangents, ``Zero`` cotangents and ``Literal`` variables are
   ignored. Before accumulation, any axis in ``reduce_axes`` (from the
   enclosing scope) that appears in ``ct``'s named shape but not in
   ``v.aval``'s named shape is summed out with ``psum``. If ``v`` already
   has an entry in ``ct_env``, it is combined with ``add_tangents``.

   ``prim`` is only used in assertion messages.
   """
   # assert v not in primal_env
   # Catch the ``Zero`` class itself being passed instead of an instance
   # (an old, harmless type error).
   assert ct is not Zero, (prim, v.aval)
   if ct is None or type(v) is Literal:
     return
   if type(ct) is Zero:
     # FIXME: This triggers a lot of failures!
     # assert v.aval == ct.aval, (prim, v.aval, ct.aval)
     return
   # Named axes carried by ct but not by v must be reduced away.
   axes_to_reduce = tuple(
       name for name in reduce_axes
       if name in core.get_aval(ct).named_shape
       and name not in v.aval.named_shape)
   if axes_to_reduce:
     ct = jax.lax.psum(ct, axis_name=axes_to_reduce)
   if v in ct_env:
     ct = add_tangents(ct_env[v], ct)
   ct_env[v] = ct
   if not config.jax_enable_checks:
     return
   # Debug-only check: the accumulated cotangent joins cleanly with v's aval,
   # ignoring weak_type and named_shape.
   stored_aval = core.get_aval(ct_env[v])
   joined = core.lattice_join(v.aval, stored_aval)
   joined = joined.strip_weak_type().strip_named_shape()
   assert v.aval.strip_weak_type().strip_named_shape() == joined, (prim, v.aval, stored_aval)
示例#4
0
def add_abstract(xs, ys):
    """Combine two abstract values by taking their lattice join.

    This is a thin wrapper that delegates to ``lattice_join(xs, ys)``.
    """
    joined = lattice_join(xs, ys)
    return joined
示例#5
0
def typecheck(aval, x):
  """Report whether value ``x`` conforms to abstract value ``aval``.

  ``aval`` is first raised to its shaped form. ``x`` conforms when joining
  that shaped aval with ``x``'s own aval leaves the shaped aval unchanged.
  A ``TypeError`` raised during the join or the comparison counts as
  "does not conform".

  Returns:
    The result of comparing the shaped aval with the join, or ``False`` if a
    ``TypeError`` was raised.
  """
  expected = raise_to_shaped(aval)
  try:
    joined = core.lattice_join(expected, core.get_aval(x))
    return expected == joined
  except TypeError:
    return False
示例#6
0
def while_loop(cond_fun, body_fun, init_val):
  """Call ``body_fun`` repeatedly in a loop while ``cond_fun`` is True.

  The type signature in brief is

  .. code-block:: haskell

    while_loop :: (a -> Bool) -> (a -> a) -> a -> a

  The semantics of ``while_loop`` are given by this Python implementation::

    def while_loop(cond_fun, body_fun, init_val):
      val = init_val
      while cond_fun(val):
        val = body_fun(val)
      return val

  Unlike that Python version, ``while_loop`` is a JAX primitive and is lowered
  to a single XLA While HLO. That makes it useful for reducing compilation times
  for jit-compiled functions, since native Python loop constructs in an ``@jit``
  function are unrolled, leading to large XLA computations.

  Another difference from using Python-native loop constructs is that
  ``while_loop`` is not reverse-mode differentiable because XLA computations
  require static bounds on memory requirements.

  Args:
    cond_fun: function of type ``a -> Bool``.
    body_fun: function of type ``a -> a``.
    init_val: value of type ``a``, a type that can be a scalar, array, or any
      pytree (nested Python tuple/list/dict) thereof, representing the initial
      loop carry value.

  Returns:
    The output from the final iteration of body_fun, of type ``a``.

  Raises:
    TypeError: if ``body_fun``'s output differs from its input in pytree
      structure or abstract type, or if ``cond_fun`` does not return a scalar
      boolean.
    ValueError: if ``cond_fun`` is constantly true (an infinite loop).
  """
  init_val_flat, in_tree = pytree_to_jaxtupletree(init_val)
  flat_body_fun, out_tree = pytree_fun_to_jaxtupletree_fun(lu.wrap_init(body_fun), (in_tree,))
  flat_cond_fun, _ = pytree_fun_to_jaxtupletree_fun(lu.wrap_init(cond_fun), (in_tree,))

  carry_pval_flat = carry_aval, _ = _abstractify(init_val_flat)
  cond_jaxpr, cond_pval_out, cond_consts = pe.trace_to_jaxpr(flat_cond_fun, (carry_pval_flat,))
  body_jaxpr, body_pval_out, body_consts = pe.trace_to_jaxpr(flat_body_fun, (carry_pval_flat,), instantiate=True)

  # body_fun's output becomes the next carry, so it must match the input in
  # both pytree structure and abstract type. The structure check comes first
  # (out_tree is populated by the tracing above) because it is the more
  # specific error. Both are real exceptions rather than asserts, so they
  # survive `python -O` and name the offending types.
  if out_tree() != in_tree:
    raise TypeError("body_fun input and output must have identical structure")
  carry_aval_out, _ = body_pval_out
  assert isinstance(carry_aval_out, core.AbstractValue)  # guaranteed by instantiate=True
  if carry_aval != core.lattice_join(carry_aval, carry_aval_out):
    msg = ("body_fun output and input must have identical types, "
           "got output {} and input {}.")
    raise TypeError(msg.format(carry_aval_out, carry_aval))

  cond_pv, cond_const = cond_pval_out
  if cond_pv is None:
    # cond_fun evaluates to a constant, so don't need to generate a while_loop
    if cond_const:
      raise ValueError("infinite loop with no effects")
    else:
      return init_val
  else:
    assert isinstance(cond_pv, core.AbstractValue)
    if (not isinstance(cond_pv, ShapedArray) or cond_pv.shape
        or cond_pv.dtype != onp.bool_):
      msg = "while_loop cond_fun must return a scalar boolean, got {}."
      raise TypeError(msg.format(cond_pv))

  out_flat = while_p.bind(
      init_val_flat, core.pack(cond_consts), core.pack(body_consts),
      aval_out=carry_aval_out, cond_jaxpr=cond_jaxpr, body_jaxpr=body_jaxpr)
  return build_tree(out_tree(), out_flat)
示例#7
0
def while_loop(cond_fun, body_fun, init_val):
    """Call ``body_fun`` repeatedly in a loop while ``cond_fun`` is True.

    The type signature in brief is

    .. code-block:: haskell

      while_loop :: (a -> Bool) -> (a -> a) -> a -> a

    The semantics of ``while_loop`` are given by this Python implementation::

      def while_loop(cond_fun, body_fun, init_val):
        val = init_val
        while cond_fun(val):
          val = body_fun(val)
        return val

    Unlike that Python version, ``while_loop`` is a JAX primitive and is lowered
    to a single XLA While HLO. That makes it useful for reducing compilation times
    for jit-compiled functions, since native Python loop constructs in an ``@jit``
    function are unrolled, leading to large XLA computations.

    Another difference from using Python-native loop constructs is that
    ``while_loop`` is not reverse-mode differentiable because XLA computations
    require static bounds on memory requirements.

    Args:
      cond_fun: function of type ``a -> Bool``.
      body_fun: function of type ``a -> a``.
      init_val: value of type ``a``, a type that can be a scalar, array, or any
        pytree (nested Python tuple/list/dict) thereof, representing the initial
        loop carry value.

    Returns:
      The output from the final iteration of body_fun, of type ``a``.

    Raises:
      TypeError: if ``body_fun``'s output differs from its input in pytree
        structure or abstract type, or if ``cond_fun`` does not return a
        scalar boolean.
      ValueError: if ``cond_fun`` is constantly true (an infinite loop).
    """
    init_val_flat, in_tree = pytree_to_jaxtupletree(init_val)
    flat_body_fun, out_tree = pytree_fun_to_jaxtupletree_fun(
        lu.wrap_init(body_fun), (in_tree, ))
    flat_cond_fun, _ = pytree_fun_to_jaxtupletree_fun(lu.wrap_init(cond_fun),
                                                      (in_tree, ))

    carry_pval_flat = carry_aval, _ = _abstractify(init_val_flat)
    cond_jaxpr, cond_pval_out, cond_consts = pe.trace_to_jaxpr(
        flat_cond_fun, (carry_pval_flat, ))
    body_jaxpr, body_pval_out, body_consts = pe.trace_to_jaxpr(
        flat_body_fun, (carry_pval_flat, ), instantiate=True)

    # body_fun's output becomes the next carry, so it must match the input in
    # both pytree structure and abstract type. The structure check comes
    # first (out_tree is populated by the tracing above) because it is the
    # more specific error. Both are real exceptions rather than asserts, so
    # they survive `python -O` and name the offending types.
    if out_tree() != in_tree:
        raise TypeError(
            "body_fun input and output must have identical structure")
    carry_aval_out, _ = body_pval_out
    # Guaranteed by instantiate=True above.
    assert isinstance(carry_aval_out, core.AbstractValue)
    if carry_aval != core.lattice_join(carry_aval, carry_aval_out):
        msg = ("body_fun output and input must have identical types, "
               "got output {} and input {}.")
        raise TypeError(msg.format(carry_aval_out, carry_aval))

    cond_pv, cond_const = cond_pval_out
    if cond_pv is None:
        # cond_fun evaluates to a constant, so don't need to generate a while_loop
        if cond_const:
            raise ValueError("infinite loop with no effects")
        else:
            return init_val
    else:
        assert isinstance(cond_pv, core.AbstractValue)
        if (not isinstance(cond_pv, ShapedArray) or cond_pv.shape
                or cond_pv.dtype != onp.bool_):
            msg = "while_loop cond_fun must return a scalar boolean, got {}."
            raise TypeError(msg.format(cond_pv))

    # We don't want to promote literal constants as loop arguments; there are
    # sometimes many of them. We pass tracers as loop arguments, but leave
    # nontracers as constants. We also sort the constants so the nontracers are
    # first.
    def split_tracers_and_nontracers(jaxpr, consts):
        """Partition ``jaxpr``'s constvars and ``consts`` by const kind.

        Returns ``(constvars, nontracer_consts, tracer_consts)``.
        ``constvars`` lists the non-tracer vars first, then the tracer vars.
        """
        tracer = []
        nontracer = []
        for var, const in zip(jaxpr.constvars, consts):
            # TODO(phawkins): We avoid treating DeviceArrays as constant literals so
            # we don't copy large arrays back to the host. We probably should relax
            # this and either always copy small constants, or opportunistically use
            # DeviceArray values for which we already know npy_value.
            not_literal_const = isinstance(const,
                                           (core.Tracer, xla.DeviceArray))
            (tracer if not_literal_const else nontracer).append((var, const))
        tracer_vars, tracer_consts = unzip2(tracer)
        nontracer_vars, nontracer_consts = unzip2(nontracer)
        return nontracer_vars + tracer_vars, nontracer_consts, tracer_consts

    cond_split = split_tracers_and_nontracers(cond_jaxpr, cond_consts)
    cond_jaxpr.constvars, cond_nontracer_consts, cond_tracer_consts = cond_split
    body_split = split_tracers_and_nontracers(body_jaxpr, body_consts)
    body_jaxpr.constvars, body_nontracer_consts, body_tracer_consts = body_split

    out_flat = while_p.bind(
        init_val_flat,
        core.pack(cond_tracer_consts),
        core.pack(body_tracer_consts),
        cond_consts=lax._OpaqueParam(cond_nontracer_consts),
        body_consts=lax._OpaqueParam(body_nontracer_consts),
        aval_out=carry_aval_out,
        cond_jaxpr=cond_jaxpr,
        body_jaxpr=body_jaxpr)
    return build_tree(out_tree(), out_flat)