Пример #1
0
Файл: ad.py Проект: 0x0is1/jax
def standard_jvp2(jvprules, primitive, primals, tangents, **params):
    """JVP rule whose per-argument rules also receive the primal output.

    ``primitive`` is bound once on ``primals``. Then, for each argument, the
    matching rule in ``jvprules`` is called as
    ``rule(tangent, primal_out, *primals, **params)``. Arguments whose rule is
    ``None`` or whose tangent is a symbolic ``Zero`` are skipped. The
    contributions are summed with ``add_tangents``, starting from
    ``Zero.from_value(primal_out)``.

    Returns:
        ``(primal_out, tangent_out)``.
    """
    primal_out = primitive.bind(*primals, **params)
    contributions = []
    for jvp_rule, tangent in zip(jvprules, tangents):
        # Nothing to contribute without a rule or with a symbolic-zero tangent.
        if jvp_rule is None or type(tangent) is Zero:
            continue
        contributions.append(jvp_rule(tangent, primal_out, *primals, **params))
    total = Zero.from_value(primal_out)
    for contribution in contributions:
        total = add_tangents(total, contribution)
    return primal_out, total
Пример #2
0
Файл: ad.py Проект: jbampton/jax
def f_jvp_traceable(nonzeros, *primals_and_nztangents):
  """Expand sparse input tangents to a dense list, then compress output tangents.

  ``primals_and_nztangents`` holds ``len(nonzeros)`` primals followed by one
  tangent for each ``True`` entry of ``nonzeros``. Each absent tangent is
  filled with ``Zero.from_value`` of its primal.

  The generator yields ``(primals, tangents), {}`` and receives
  ``(primals_out, tangents_out)``. It finally yields
  ``(list(primals_out) + nonzero_out_tangents, out_nonzeros)``, where
  ``out_nonzeros[i]`` records whether output tangent ``i`` is not a symbolic
  ``Zero``.
  """
  n_primals = len(nonzeros)
  primals = list(primals_and_nztangents[:n_primals])
  supplied = iter(primals_and_nztangents[n_primals:])
  tangents = []
  for primal, has_tangent in zip(primals, nonzeros):
    if has_tangent:
      tangents.append(next(supplied))
    else:
      tangents.append(Zero.from_value(primal))
  primals_out, tangents_out = yield (primals, tangents), {}
  out_nonzeros = [type(t_out) is not Zero for t_out in tangents_out]
  kept_tangents = [t_out for t_out in tangents_out if type(t_out) is not Zero]
  yield list(primals_out) + kept_tangents, out_nonzeros
Пример #3
0
def _remat_transpose(primal_jaxpr, tangent_jaxpr, reduce_axes,
                     primals_tangents_in, cotangents_in):
  """Transpose rule that recomputes residuals from the known primal inputs.

  The entries of ``primals_tangents_in`` that are not undefined primals are fed
  to ``primal_jaxpr`` to produce residuals. ``backward_pass`` is then run over
  ``tangent_jaxpr.jaxpr`` with inputs ``(*residuals, *undefined_inputs)``.

  Returns:
    One entry per element of ``primals_tangents_in``: the computed cotangent
    for each undefined primal, and ``Zero.from_value(x)`` for each known
    primal ``x``.
  """
  known = [x for x in primals_tangents_in if not is_undefined_primal(x)]
  undefined = [x for x in primals_tangents_in if is_undefined_primal(x)]
  residuals = core.jaxpr_as_fun(primal_jaxpr)(*known)
  all_cts = backward_pass(tangent_jaxpr.jaxpr, reduce_axes, False, (),
                          (*residuals, *undefined), cotangents_in)
  # The first len(residuals) cotangents belong to the residual inputs and are
  # dropped; the rest line up, in order, with the undefined inputs.
  remaining = iter(all_cts[len(residuals):])
  result = []
  for x in primals_tangents_in:
    if is_undefined_primal(x):
      result.append(next(remaining))
    else:
      result.append(Zero.from_value(x))
  # No cotangent may be left over.
  assert next(remaining, None) is None
  return result
Пример #4
0
Файл: ad.py Проект: jbampton/jax
def jvpfun(instantiate, transform_stack, primals, tangents):
  """Generator that runs a function under a fresh ``JVPTrace`` main trace.

  Args:
    instantiate: a bool applied to every output tangent, or a sequence with
      one bool per output tangent. A true flag passes that tangent through
      ``instantiate_zeros``.
    transform_stack: when true, the body runs inside
      ``source_info_util.transform_name_stack('jvp')``.
    primals: the primal inputs.
    tangents: the input tangents. A non-``Zero`` tangent whose dtype is
      ``float0`` is replaced with ``Zero.from_value`` of itself.

  The generator yields ``(main, primals, tangents), {}`` and receives
  ``(out_primals, out_tangents)``. It then yields the pair with the requested
  tangents instantiated.
  """
  def _canonical(t):
    if isinstance(t, Zero) or dtype(t) is not float0:
      return t
    return Zero.from_value(t)

  tangents = [_canonical(t) for t in tangents]
  if transform_stack:
    ctx = source_info_util.transform_name_stack('jvp')
  else:
    ctx = contextlib.nullcontext()
  with core.new_main(JVPTrace) as main, ctx:
    out_primals, out_tangents = yield (main, primals, tangents), {}
    # Release this frame's reference to `main` before the `with` exits.
    del main
  if type(instantiate) is bool:
    flags = [instantiate] * len(out_tangents)
  else:
    flags = instantiate
  out_tangents = [instantiate_zeros(t_out) if flag else t_out
                  for t_out, flag in zip(out_tangents, flags)]
  yield out_primals, out_tangents
Пример #5
0
def _matchaxis_symbolic_zeros(axis_name,
                              sz,
                              name,
                              src,
                              dst,
                              x,
                              sum_match=False):
    """Variant of ``matchaxis`` that also accepts symbolic ``Zero`` values.

    A non-``Zero`` ``x`` is delegated unchanged to ``matchaxis``. For a
    ``Zero``, only its aval is recomputed:

    * ``src == dst``: ``x`` is returned as is.
    * ``src`` and ``dst`` both ints: the aval is passed through
      ``core.mapped_aval(sz, src, ...)`` and then
      ``core.unmapped_aval(sz, name, dst, ...)``.
    * ``src`` is ``not_mapped`` and ``dst`` is not: the aval is passed through
      ``core.unmapped_aval(sz, name, dst, ...)``.
    * ``dst`` is ``not_mapped`` and ``sum_match`` is true: the aval is passed
      through ``core.mapped_aval(sz, src, ...)``.

    Raises:
        ValueError: for any other ``src``/``dst`` combination on a ``Zero``.
    """
    # TODO(mattjj): dedup with matchaxis
    if not isinstance(x, Zero):
        return matchaxis(axis_name, sz, src, dst, x, sum_match=sum_match)
    if src == dst:
        return x
    if type(src) == type(dst) == int:
        intermediate = core.mapped_aval(sz, src, x.aval)
        return Zero(core.unmapped_aval(sz, name, dst, intermediate))
    if src is not_mapped and dst is not not_mapped:
        return Zero(core.unmapped_aval(sz, name, dst, x.aval))
    if dst is not_mapped and sum_match:
        return Zero(core.mapped_aval(sz, src, x.aval))
    raise ValueError((axis_name, x, src, dst))
Пример #6
0
def jvpfun(instantiate, primals, tangents):
    """Generator that runs a function under a fresh ``JVPTrace`` main trace.

    Any non-``Zero`` tangent whose dtype is ``float0`` is first replaced by
    ``Zero.from_value`` of itself. The generator yields
    ``(main, primals, tangents), {}`` and receives
    ``(out_primals, out_tangents)``. It finally yields that pair, passing each
    output tangent whose flag is true through ``instantiate_zeros``.
    ``instantiate`` is either a single bool for all outputs or one bool per
    output.
    """
    canonical = []
    for t in tangents:
        if isinstance(t, Zero) or dtype(t) is not float0:
            canonical.append(t)
        else:
            canonical.append(Zero.from_value(t))
    tangents = canonical
    with core.new_main(JVPTrace) as main:
        out_primals, out_tangents = yield (main, primals, tangents), {}
        # Release this frame's reference to `main` before the `with` exits.
        del main
    if type(instantiate) is bool:
        flags = [instantiate] * len(out_tangents)
    else:
        flags = instantiate
    result = []
    for t_out, flag in zip(out_tangents, flags):
        result.append(instantiate_zeros(t_out) if flag else t_out)
    out_tangents = result
    yield out_primals, out_tangents
Пример #7
0
def _flatten_bwd(in_tree, in_avals, out_trees, *args):
    """Adapt a pytree-level custom-VJP backward rule to flat leaf lists.

    ``args`` holds the flat residual leaves followed by the flat output
    cotangent leaves. ``out_trees()`` returns ``(out_tree, res_tree)``, which
    is used to split and unflatten them.

    The generator yields ``(py_res, py_cts_out), {}`` and receives the
    rule's result ``py_cts_in``. That result must be a tuple matching the
    structure of ``in_tree``, with ``None`` allowed for whole subtrees.

    Finally the generator yields one flat cotangent per leaf of ``in_tree``.
    A leaf gets a symbolic ``Zero`` when its rule entry was ``None`` or when
    its aval differs from ``a.at_least_vspace()``.

    Raises:
        TypeError: if ``py_cts_in`` is not a tuple, or if ``tree_multimap``
            raises ``ValueError`` while pairing it with ``in_tree``
            (presumably a structure mismatch).
    """
    out_tree, res_tree = out_trees()
    # Residual leaves come first, followed by the output cotangent leaves.
    res, cts_out = split_list(args, [res_tree.num_leaves])
    py_res = tree_unflatten(res_tree, res)
    py_cts_out = tree_unflatten(out_tree, cts_out)
    py_cts_in = yield (py_res, py_cts_out), {}
    # For each None in py_cts_in, indicating an argument for which the rule
    # produces no cotangent, we replace it with a pytree with the structure of the
    # corresponding subtree of in_tree and with leaves of a non-pytree sentinel
    # object, to be replaced with Nones in the final returned result.
    zero = object()  # non-pytree sentinel to replace Nones in py_cts_in
    # Placeholder pytree shaped like in_tree; only its per-subtree leaf
    # counts are used below.
    dummy = tree_unflatten(in_tree, [object()] * in_tree.num_leaves)
    cts_in_flat = []
    # Append x once per leaf of the matching in_tree subtree d.
    append_cts = lambda x, d: cts_in_flat.extend([x] * len(tree_flatten(d)[0]))
    try:
        if not isinstance(py_cts_in, tuple):
            raise ValueError
        tree_multimap(append_cts,
                      tuple(zero if ct is None else ct for ct in py_cts_in),
                      dummy)
    except ValueError:
        # Report the structure the user's rule actually returned.
        _, in_tree2 = tree_flatten(py_cts_in)
        msg = (
            "Custom VJP rule must produce an output with the same container "
            "(pytree) structure as the args tuple of the primal function, "
            "and in particular must produce a tuple of length equal to the "
            "number of arguments to the primal function, but got VJP output "
            "structure {} for primal input structure {}.")
        raise TypeError(msg.format(in_tree2, in_tree)) from None
    # Ignore any None cotangents, and any corresponding to inputs for which the
    # type doesn't equal the tangent type (i.e. float0s)
    # TODO(mattjj): change this to check if tangent type represents 0dim vspace
    yield [
        Zero(a.at_least_vspace())
        if ct is zero or a != a.at_least_vspace() else ct
        for a, ct in zip(in_avals, cts_in_flat)
    ]
Пример #8
0
Файл: ad.py Проект: jbampton/jax
 def unmap_zero(zero, in_axis):
   """Return ``zero`` with its aval unmapped along ``in_axis``; ``None`` means no change.

   Reads ``params['axis_size']`` and ``params['axis_name']`` from the
   enclosing scope and passes them to ``core.unmapped_aval``.
   """
   if in_axis is None:
     return zero
   unmapped = core.unmapped_aval(params['axis_size'], params['axis_name'],
                                 in_axis, zero.aval)
   return Zero(unmapped)
Пример #9
0
Файл: ad.py Проект: jbampton/jax
def zero_jvp(primitive, primals, tangents, **params):
  """JVP rule for primitives whose output tangent is always zero.

  ``tangents`` is ignored. Returns the result of binding ``primitive`` on
  ``primals``, paired with ``Zero.from_value`` of that result.
  """
  out = primitive.bind(*primals, **params)
  zero_tangent = Zero.from_value(out)
  return out, zero_tangent
Пример #10
0
Файл: ad.py Проект: jbampton/jax
 def lift(self, val):
   """Wrap ``val`` in a ``JVPTracer`` of this trace with a symbolic-zero tangent."""
   zero_aval = get_aval(val).at_least_vspace()
   return JVPTracer(self, val, Zero(zero_aval))
Пример #11
0
Файл: ad.py Проект: jbampton/jax
 def read_cotangent(v):
   """Remove and return the cotangent stored for ``v`` in ``ct_env``.

   If none was stored, returns ``Zero(v.aval)``.
   """
   fallback = Zero(v.aval)
   return ct_env.pop(v, fallback)
Пример #12
0
Файл: ad.py Проект: jbampton/jax
def backward_pass(jaxpr: core.Jaxpr, reduce_axes, transform_stack,
                  consts, primals_in, cotangents_in):
  """Transpose ``jaxpr`` by propagating cotangents from outputs to inputs.

  Args:
    jaxpr: the jaxpr to transpose.
    reduce_axes: named axes over which a written cotangent is ``psum``-ed.
      This applies when the cotangent's aval has the axis in its
      ``named_shape`` but the target variable's aval does not.
    transform_stack: when true, the pass runs inside
      ``source_info_util.transform_name_stack('transpose')``.
    consts: values for ``jaxpr.constvars``.
    primals_in: values for ``jaxpr.invars``. Entries for which
      ``is_undefined_primal`` holds are not stored in the primal env.
    cotangents_in: cotangents for ``jaxpr.outvars``.

  Returns:
    One cotangent per entry of ``jaxpr.invars``. An entry is ``Zero(v.aval)``
    where nothing was accumulated.
  """
  # Fast path: all-zero incoming cotangents give all-zero outgoing ones.
  if all(type(ct) is Zero for ct in cotangents_in):
    return map(lambda v: Zero(v.aval), jaxpr.invars)

  def write_cotangent(prim, v, ct):
    # Accumulate ct into ct_env[v]. prim is only used in assertion messages.
    # None cotangents, symbolic zeros, and Literal targets are dropped.
    # assert v not in primal_env
    assert ct is not Zero, (prim, v.aval)  # check for an old harmless type error
    if ct is None or type(v) is Literal:
      return
    if type(ct) is Zero:
      # FIXME: This triggers a lot of failures!
      # assert v.aval == ct.aval, (prim, v.aval, ct.aval)
      return
    # Sum out named axes present on the cotangent but absent from v's aval.
    axes_to_reduce = tuple(axis_name for axis_name in reduce_axes
                           if axis_name in core.get_aval(ct).named_shape
                           and axis_name not in v.aval.named_shape)
    if axes_to_reduce:
      ct = jax.lax.psum(ct, axis_name=axes_to_reduce)
    ct_env[v] = add_tangents(ct_env[v], ct) if v in ct_env else ct
    if config.jax_enable_checks:
      ct_aval = core.get_aval(ct_env[v])
      joined_aval = core.lattice_join(v.aval, ct_aval).strip_weak_type().strip_named_shape()
      assert v.aval.strip_weak_type().strip_named_shape() == joined_aval, (prim, v.aval, ct_aval)

  def read_cotangent(v):
    # Consume v's accumulated cotangent, defaulting to a symbolic zero.
    return ct_env.pop(v, Zero(v.aval))

  def read_primal(v):
    # Literals carry their value; unknown variables read as UndefinedPrimal.
    if type(v) is Literal:
      return v.val
    else:
      return primal_env.get(v, UndefinedPrimal(v.aval))

  def write_primal(v, val):
    # Only known (defined) primal values are recorded.
    if not is_undefined_primal(val):
      primal_env[v] = val

  primal_env: Dict[Any, Any] = {}
  write_primal(core.unitvar, core.unit)
  map(write_primal, jaxpr.constvars, consts)
  # FIXME: invars can contain both primal and tangent values, and this line
  #        forces primal_in to contain UndefinedPrimals for tangent values!
  map(write_primal, jaxpr.invars, primals_in)

  ct_env: Dict[Any, Any] = {}
  ctx = (source_info_util.transform_name_stack('transpose') if transform_stack
         else contextlib.nullcontext())
  with ctx:
    # Seed the output cotangents, then walk the equations in reverse order.
    map(partial(write_cotangent, 'outvars'), jaxpr.outvars, cotangents_in)
    for eqn in jaxpr.eqns[::-1]:
      # FIXME: Some invars correspond to tangents
      invals = map(read_primal, eqn.invars)
      if eqn.primitive.multiple_results:
        cts_in = map(read_cotangent, eqn.outvars)
      else:
        cts_in, = map(read_cotangent, eqn.outvars)
      name_stack = source_info_util.current_name_stack() + eqn.source_info.name_stack
      with source_info_util.user_context(eqn.source_info.traceback, name_stack=name_stack):
        # Call/map primitives get the sub-jaxpr and output avals separately;
        # reducing transposes also receive reduce_axes.
        if eqn.primitive.call_primitive or eqn.primitive.map_primitive:
          cts_in_avals = [v.aval for v in eqn.outvars]
          params = dict(eqn.params)
          call_jaxpr = params.pop('call_jaxpr')
          cts_out = get_primitive_transpose(eqn.primitive)(
              params, call_jaxpr, invals, cts_in, cts_in_avals, reduce_axes)
        elif eqn.primitive in reducing_transposes:
          cts_out = reducing_transposes[eqn.primitive](
              reduce_axes, cts_in, *invals, **eqn.params)
        else:
          cts_out = get_primitive_transpose(eqn.primitive)(
              cts_in, *invals, **eqn.params)
        # A transpose rule may return the Zero class itself as shorthand for
        # "every input cotangent is zero".
        cts_out = [Zero(v.aval) for v in eqn.invars] if cts_out is Zero else cts_out
        # FIXME: Some invars correspond to primals!
        map(partial(write_cotangent, eqn.primitive), eqn.invars, cts_out)

  cotangents_out = map(read_cotangent, jaxpr.invars)
  return cotangents_out
Пример #13
0
Файл: ad.py Проект: jbampton/jax
def recast_to_float0(primal, tangent):
  """Replace ``tangent`` with a symbolic zero when ``primal`` has float0 tangents.

  When the tangent dtype for ``primal``'s dtype is ``float0``, returns
  ``Zero(get_aval(primal).at_least_vspace())``. Otherwise returns ``tangent``
  unchanged.
  """
  tangent_dtype = core.primal_dtype_to_tangent_dtype(dtype(primal))
  if tangent_dtype == float0:
    return Zero(get_aval(primal).at_least_vspace())
  return tangent
Пример #14
0
Файл: ad.py Проект: jbampton/jax
    return Zero if out is Zero else (out, None)
  else:
    out = rhs_rule(cotangent, x, **kwargs)
    return Zero if out is Zero else (None, out)


def defjvp_zero(primitive):
  """Register ``zero_jvp`` (bound to ``primitive``) as ``primitive``'s JVP rule."""
  assert isinstance(primitive, Primitive)
  rule = partial(zero_jvp, primitive)
  primitive_jvps[primitive] = rule

def zero_jvp(primitive, primals, tangents, **params):
  """JVP rule that always yields a symbolic-zero output tangent.

  ``tangents`` is ignored. Returns ``(out, Zero.from_value(out))``, where
  ``out`` is ``primitive`` bound on ``primals``.
  """
  result = primitive.bind(*primals, **params)
  return result, Zero.from_value(result)


# NOTE(review): deflinear2 is defined elsewhere. These lambdas look like
# transpose rules that take the output cotangent ``t`` first — confirm.
# zeros_like: the input's cotangent is Zero.from_value(t).
deflinear2(zeros_like_p, lambda t, _: [Zero.from_value(t)])
# add_jaxvals: both addends receive the cotangent ``t`` unchanged.
deflinear2(add_jaxvals_p, lambda t, *args: (t, t))

def instantiate_zeros(tangent):
  """Turn a symbolic ``Zero`` into ``zeros_like_aval`` of its aval.

  Any other tangent is returned unchanged.
  """
  if type(tangent) is not Zero:
    return tangent
  return zeros_like_aval(tangent.aval)

# This function seems similar to instantiate_zeros, but it is sometimes used
# to instantiate zero abstract units with a different aval
def instantiate_zeros_aval(aval, tangent):
  if type(tangent) is Zero:
    assert type(tangent.aval) is core.AbstractUnit or tangent.aval == aval
    return zeros_like_aval(aval)
  else:
Пример #15
0
def backward_pass(jaxpr: core.Jaxpr, consts, primals_in, cotangents_in):
    """Transpose ``jaxpr`` by propagating cotangents from outputs to inputs.

    Args:
        jaxpr: the jaxpr to transpose.
        consts: values for ``jaxpr.constvars``.
        primals_in: values for ``jaxpr.invars``. Entries for which
            ``is_undefined_primal`` holds are not stored in the primal env.
        cotangents_in: cotangents for ``jaxpr.outvars``.

    Returns:
        One cotangent per entry of ``jaxpr.invars``. An entry is
        ``Zero(v.aval)`` where nothing was accumulated.
    """
    # Fast path: all-zero incoming cotangents give all-zero outgoing ones.
    if all(type(ct) is Zero for ct in cotangents_in):
        return map(lambda v: Zero(v.aval), jaxpr.invars)

    def write_cotangent(prim, v, ct):
        # Accumulate ct into ct_env[v]. prim is only used in assertion
        # messages. None cotangents, symbolic zeros, and Literal targets are
        # dropped.
        # assert v not in primal_env
        assert ct is not Zero, (prim, v.aval
                                )  # check for an old harmless type error
        if ct is None or type(v) is Literal:
            return
        if type(ct) is Zero:
            # FIXME: This triggers a lot of failures!
            # assert v.aval == ct.aval, (prim, v.aval, ct.aval)
            return
        ct_env[v] = add_tangents(ct_env[v], ct) if v in ct_env else ct
        if config.jax_enable_checks:
            ct_aval = core.get_aval(ct_env[v])
            joined_aval = core.lattice_join(
                v.aval, ct_aval).strip_weak_type().strip_named_shape()
            assert v.aval.strip_weak_type().strip_named_shape(
            ) == joined_aval, (prim, v.aval, ct_aval)

    def read_cotangent(v):
        # Consume v's accumulated cotangent, defaulting to a symbolic zero.
        return ct_env.pop(v, Zero(v.aval))

    def read_primal(v):
        # Literals carry their value; unknown variables read as
        # UndefinedPrimal.
        if type(v) is Literal:
            return v.val
        else:
            return primal_env.get(v, UndefinedPrimal(v.aval))

    def write_primal(v, val):
        # Only known (defined) primal values are recorded.
        if not is_undefined_primal(val):
            primal_env[v] = val

    primal_env: Dict[Any, Any] = {}
    write_primal(core.unitvar, core.unit)
    map(write_primal, jaxpr.constvars, consts)
    # FIXME: invars can contain both primal and tangent values, and this line
    #        forces primal_in to contain UndefinedPrimals for tangent values!
    map(write_primal, jaxpr.invars, primals_in)

    ct_env: Dict[Any, Any] = {}
    # Seed the output cotangents, then walk the equations in reverse order.
    map(partial(write_cotangent, 'outvars'), jaxpr.outvars, cotangents_in)
    for eqn in jaxpr.eqns[::-1]:
        # FIXME: Some invars correspond to tangents
        invals = map(read_primal, eqn.invars)
        if eqn.primitive.multiple_results:
            cts_in = map(read_cotangent, eqn.outvars)
        else:
            cts_in, = map(read_cotangent, eqn.outvars)
        with source_info_util.user_context(eqn.source_info):
            # Call/map primitives get the extracted sub-jaxpr and the output
            # avals as separate arguments.
            if eqn.primitive.call_primitive or eqn.primitive.map_primitive:
                cts_in_avals = [v.aval for v in eqn.outvars]
                call_jaxpr, params = core.extract_call_jaxpr(
                    eqn.primitive, eqn.params)
                cts_out = get_primitive_transpose(eqn.primitive)(params,
                                                                 call_jaxpr,
                                                                 invals,
                                                                 cts_in,
                                                                 cts_in_avals)
            else:
                cts_out = get_primitive_transpose(eqn.primitive)(cts_in,
                                                                 *invals,
                                                                 **eqn.params)
        # A transpose rule may return the Zero class itself as shorthand for
        # "every input cotangent is zero".
        cts_out = [Zero(v.aval)
                   for v in eqn.invars] if cts_out is Zero else cts_out
        # FIXME: Some invars correspond to primals!
        map(partial(write_cotangent, eqn.primitive), eqn.invars, cts_out)

    cotangents_out = map(read_cotangent, jaxpr.invars)
    return cotangents_out