예제 #1
0
def _matchaxis_symbolic_zeros(sz, name, src, dst, x, sum_match=False):
    """Variant of `matchaxis` that also accepts symbolic zeros (ad_util.py).

    When `x` is a `Zero`, only the abstract value it carries is rewritten and
    a new `Zero` is returned (or `x` itself when nothing has to change). Any
    other value is forwarded to `matchaxis`.

    Args:
        sz: forwarded as the first argument of `core.mapped_aval`,
            `core.unmapped_aval` and `matchaxis` (presumably the size of the
            mapped axis -- confirm against callers).
        name: forwarded to `core.unmapped_aval` (presumably the axis name --
            confirm against callers).
        src: axis currently associated with `x`; compared against `dst`, `int`
            and `not_mapped`.
        dst: requested axis; compared against `src`, `int` and `not_mapped`.
        x: a `Zero`, or a value handled by `matchaxis`.
        sum_match: enables the `dst is not_mapped` case for a `Zero`; also
            forwarded to `matchaxis`.

    Returns:
        `x` itself, a new `Zero` with an adjusted aval, or the result of
        `matchaxis`.

    Raises:
        ValueError: if `x` is a `Zero` and the `(src, dst)` combination is not
            one of the handled cases.
    """
    # Non-symbolic values take the ordinary path.
    if not isinstance(x, Zero):
        return matchaxis(sz, src, dst, x, sum_match=sum_match)

    if src == dst:
        return x

    if type(src) == type(dst) == int:
        # Two-step rewrite of the aval: `mapped_aval` at `src`, then
        # `unmapped_aval` at `dst`.
        intermediate_aval = core.mapped_aval(sz, src, x.aval)
        return Zero(core.unmapped_aval(sz, name, dst, intermediate_aval))

    if src is not_mapped and dst is not not_mapped:
        return Zero(core.unmapped_aval(sz, name, dst, x.aval))

    if dst is not_mapped and sum_match:
        return Zero(core.mapped_aval(sz, src, x.aval))

    raise ValueError((x, src, dst))
def _flatten_bwd(in_tree, in_avals, out_trees, *args):
  """Generator bridging a pytree-level custom VJP rule and flat argument lists.

  The flat `args` are split into residuals and output cotangents and rebuilt
  into pytrees, which are yielded as `((py_res, py_cts_out), {})`. The value
  sent back into the generator is taken as the rule's output; it is checked
  against `in_tree`, flattened, and yielded as a flat list with one entry per
  element of `in_avals`.

  Args:
    in_tree: pytree structure of the primal function's args tuple.
    in_avals: one abstract value per flat primal input; each must provide
      `at_least_vspace()`.
    out_trees: zero-argument callable returning `(out_tree, res_tree)`.
    *args: flat residual leaves followed by flat output-cotangent leaves.

  Raises:
    TypeError: if the value sent back is not a tuple, or `tree_map` raises
      `ValueError` while pairing it with the structure of `in_tree`.
  """
  out_tree, res_tree = out_trees()
  num_res = res_tree.num_leaves
  assert len(args) == num_res + out_tree.num_leaves
  res, cts_out = split_list(args, [num_res])
  py_cts_in = yield (tree_unflatten(res_tree, res),
                     tree_unflatten(out_tree, cts_out)), {}

  # A None in py_cts_in means the rule produced no cotangent for that argument.
  # Each None is swapped for a non-pytree sentinel, which is then repeated once
  # per leaf of the corresponding subtree of in_tree, so the flattened list
  # lines up with in_avals.
  no_cotangent = object()
  template = tree_unflatten(in_tree, [object()] * in_tree.num_leaves)
  cts_in_flat = []

  def record(ct, template_subtree):
    # Repeat `ct` once for every leaf in the matching part of the template.
    leaf_count = len(tree_flatten(template_subtree)[0])
    cts_in_flat.extend([ct] * leaf_count)
    return ct

  try:
    if not isinstance(py_cts_in, tuple):
      raise ValueError
    with_sentinels = tuple(no_cotangent if ct is None else ct
                           for ct in py_cts_in)
    tree_map(record, with_sentinels, template)
  except ValueError:
    in_tree2 = tree_flatten(py_cts_in)[1]
    msg = ("Custom VJP rule must produce an output with the same container "
           "(pytree) structure as the args tuple of the primal function, "
           "and in particular must produce a tuple of length equal to the "
           "number of arguments to the primal function, but got VJP output "
           "structure {} for primal input structure {}.")
    raise TypeError(msg.format(in_tree2, in_tree)) from None

  # Ignore any None cotangents, and any corresponding to inputs for which the
  # type doesn't equal the tangent type (i.e. float0s)
  # TODO(mattjj): change this to check if tangent type represents 0dim vspace
  flat_result = []
  for aval, ct in zip(in_avals, cts_in_flat):
    if ct is no_cotangent or aval != aval.at_least_vspace():
      flat_result.append(Zero(aval.at_least_vspace()))
    else:
      flat_result.append(ct)
  yield flat_result
예제 #3
0
파일: ad.py 프로젝트: jbampton/jax
 def unmap_zero(zero, in_axis):
   """Return `zero` unchanged when `in_axis` is None, else a `Zero` whose aval
   is `core.unmapped_aval` applied with the enclosing `params`' axis size and
   axis name at `in_axis`."""
   if in_axis is None:
     return zero
   unmapped = core.unmapped_aval(
       params['axis_size'], params['axis_name'], in_axis, zero.aval)
   return Zero(unmapped)
예제 #4
0
파일: ad.py 프로젝트: jbampton/jax
 def lift(self, val):
   """Wrap `val` in a `JVPTracer` for this trace with a symbolic `Zero` tangent.

   The tangent's aval is `get_aval(val).at_least_vspace()`.
   """
   tangent_aval = get_aval(val).at_least_vspace()
   symbolic_zero = Zero(tangent_aval)
   return JVPTracer(self, val, symbolic_zero)
예제 #5
0
파일: ad.py 프로젝트: jbampton/jax
 def read_cotangent(v):
   """Remove and return the cotangent stored for `v` in `ct_env`, or a
   symbolic `Zero(v.aval)` when none is stored."""
   # The fallback is built unconditionally, before the lookup.
   fallback = Zero(v.aval)
   return ct_env.pop(v, fallback)
예제 #6
0
파일: ad.py 프로젝트: jbampton/jax
def backward_pass(jaxpr: core.Jaxpr, reduce_axes, transform_stack,
                  consts, primals_in, cotangents_in):
  """Propagate cotangents backwards through `jaxpr`, one equation at a time.

  Equations are visited in reverse order. For each one, the known primal
  inputs and the accumulated cotangents of its outputs are read, the
  primitive's transpose rule is invoked, and the resulting cotangents are
  accumulated onto the equation's input variables.

  NOTE(review): `map` is called here purely for side effects and its result is
  returned/unpacked as a sequence, so this presumably relies on a module-level
  eager `map` (e.g. an alias of a `safe_map` helper) -- confirm.

  Args:
    jaxpr: the jaxpr to transpose; its `invars`, `constvars`, `outvars` and
      `eqns` are used.
    reduce_axes: axis names; a cotangent is `psum`-ed over every listed axis
      that appears in the cotangent's `named_shape` but not in the target
      variable's. Also forwarded to call/map and reducing transpose rules.
    transform_stack: when truthy, the pass runs inside
      `source_info_util.transform_name_stack('transpose')`.
    consts: values for `jaxpr.constvars`.
    primals_in: values for `jaxpr.invars`; undefined primals are not recorded.
    cotangents_in: cotangents for `jaxpr.outvars`.

  Returns:
    One cotangent per entry of `jaxpr.invars`, a `Zero` where none was
    accumulated.
  """
  # Nothing to propagate: every output cotangent is a symbolic zero.
  if all(type(ct) is Zero for ct in cotangents_in):
    return map(lambda v: Zero(v.aval), jaxpr.invars)

  # Accumulate cotangent `ct` for variable `v`; `prim` is only used in
  # assertion messages.
  def write_cotangent(prim, v, ct):
    # assert v not in primal_env
    assert ct is not Zero, (prim, v.aval)  # check for an old harmless type error
    # No cotangent was produced, or the target is a literal: nothing to store.
    if ct is None or type(v) is Literal:
      return
    # Symbolic zeros contribute nothing to the accumulated sum.
    if type(ct) is Zero:
      # FIXME: This triggers a lot of failures!
      # assert v.aval == ct.aval, (prim, v.aval, ct.aval)
      return
    # Axes requested for reduction that the cotangent carries but the
    # variable's aval does not.
    axes_to_reduce = tuple(axis_name for axis_name in reduce_axes
                           if axis_name in core.get_aval(ct).named_shape
                           and axis_name not in v.aval.named_shape)
    if axes_to_reduce:
      ct = jax.lax.psum(ct, axis_name=axes_to_reduce)
    ct_env[v] = add_tangents(ct_env[v], ct) if v in ct_env else ct
    if config.jax_enable_checks:
      # Debug check: joining the variable's aval with the accumulated
      # cotangent's aval must give back the variable's aval (ignoring weak
      # types and named shapes).
      ct_aval = core.get_aval(ct_env[v])
      joined_aval = core.lattice_join(v.aval, ct_aval).strip_weak_type().strip_named_shape()
      assert v.aval.strip_weak_type().strip_named_shape() == joined_aval, (prim, v.aval, ct_aval)

  # Reading removes the entry from `ct_env`; a missing entry yields a Zero.
  def read_cotangent(v):
    return ct_env.pop(v, Zero(v.aval))

  # Literals carry their own value; unknown variables become UndefinedPrimal.
  def read_primal(v):
    if type(v) is Literal:
      return v.val
    else:
      return primal_env.get(v, UndefinedPrimal(v.aval))

  # Only defined primal values are recorded.
  def write_primal(v, val):
    if not is_undefined_primal(val):
      primal_env[v] = val

  primal_env: Dict[Any, Any] = {}
  write_primal(core.unitvar, core.unit)
  map(write_primal, jaxpr.constvars, consts)
  # FIXME: invars can contain both primal and tangent values, and this line
  #        forces primal_in to contain UndefinedPrimals for tangent values!
  map(write_primal, jaxpr.invars, primals_in)

  ct_env: Dict[Any, Any] = {}
  ctx = (source_info_util.transform_name_stack('transpose') if transform_stack
         else contextlib.nullcontext())
  with ctx:
    # Seed the environment with the cotangents of the jaxpr's outputs.
    map(partial(write_cotangent, 'outvars'), jaxpr.outvars, cotangents_in)
    for eqn in jaxpr.eqns[::-1]:
      # FIXME: Some invars correspond to tangents
      invals = map(read_primal, eqn.invars)
      # Single-result primitives get a bare cotangent rather than a sequence.
      if eqn.primitive.multiple_results:
        cts_in = map(read_cotangent, eqn.outvars)
      else:
        cts_in, = map(read_cotangent, eqn.outvars)
      name_stack = source_info_util.current_name_stack() + eqn.source_info.name_stack
      with source_info_util.user_context(eqn.source_info.traceback, name_stack=name_stack):
        if eqn.primitive.call_primitive or eqn.primitive.map_primitive:
          # Call/map transposes receive the params with 'call_jaxpr' removed
          # and the inner jaxpr as a separate argument.
          cts_in_avals = [v.aval for v in eqn.outvars]
          params = dict(eqn.params)
          call_jaxpr = params.pop('call_jaxpr')
          cts_out = get_primitive_transpose(eqn.primitive)(
              params, call_jaxpr, invals, cts_in, cts_in_avals, reduce_axes)
        elif eqn.primitive in reducing_transposes:
          cts_out = reducing_transposes[eqn.primitive](
              reduce_axes, cts_in, *invals, **eqn.params)
        else:
          cts_out = get_primitive_transpose(eqn.primitive)(
              cts_in, *invals, **eqn.params)
        # A rule may return the `Zero` class itself; expand that to one Zero
        # per input variable.
        cts_out = [Zero(v.aval) for v in eqn.invars] if cts_out is Zero else cts_out
        # FIXME: Some invars correspond to primals!
        map(partial(write_cotangent, eqn.primitive), eqn.invars, cts_out)

  cotangents_out = map(read_cotangent, jaxpr.invars)
  return cotangents_out
예제 #7
0
파일: ad.py 프로젝트: jbampton/jax
def recast_to_float0(primal, tangent):
  """Replace `tangent` with a symbolic `Zero` when `primal`'s tangent dtype is `float0`.

  Returns `Zero(get_aval(primal).at_least_vspace())` if
  `core.primal_dtype_to_tangent_dtype(dtype(primal))` equals `float0`;
  otherwise returns `tangent` unchanged.
  """
  tangent_dtype = core.primal_dtype_to_tangent_dtype(dtype(primal))
  is_float0 = tangent_dtype == float0
  return Zero(get_aval(primal).at_least_vspace()) if is_float0 else tangent
예제 #8
0
def backward_pass(jaxpr: core.Jaxpr, consts, primals_in, cotangents_in):
    """Propagate cotangents backwards through `jaxpr`, one equation at a time.

    Equations are visited in reverse order. For each one, the known primal
    inputs and the accumulated cotangents of its outputs are read, the
    primitive's transpose rule is invoked, and the resulting cotangents are
    accumulated onto the equation's input variables.

    NOTE(review): `map` is called here purely for side effects and its result
    is returned/unpacked as a sequence, so this presumably relies on a
    module-level eager `map` (e.g. an alias of a `safe_map` helper) -- confirm.

    Args:
        jaxpr: the jaxpr to transpose; its `invars`, `constvars`, `outvars`
            and `eqns` are used.
        consts: values for `jaxpr.constvars`.
        primals_in: values for `jaxpr.invars`; undefined primals are not
            recorded.
        cotangents_in: cotangents for `jaxpr.outvars`.

    Returns:
        One cotangent per entry of `jaxpr.invars`, a `Zero` where none was
        accumulated.
    """
    # Nothing to propagate: every output cotangent is a symbolic zero.
    if all(type(ct) is Zero for ct in cotangents_in):
        return map(lambda v: Zero(v.aval), jaxpr.invars)

    # Accumulate cotangent `ct` for variable `v`; `prim` is only used in
    # assertion messages.
    def write_cotangent(prim, v, ct):
        # assert v not in primal_env
        assert ct is not Zero, (prim, v.aval
                                )  # check for an old harmless type error
        # No cotangent was produced, or the target is a literal: nothing to
        # store.
        if ct is None or type(v) is Literal:
            return
        # Symbolic zeros contribute nothing to the accumulated sum.
        if type(ct) is Zero:
            # FIXME: This triggers a lot of failures!
            # assert v.aval == ct.aval, (prim, v.aval, ct.aval)
            return
        ct_env[v] = add_tangents(ct_env[v], ct) if v in ct_env else ct
        if config.jax_enable_checks:
            # Debug check: joining the variable's aval with the accumulated
            # cotangent's aval must give back the variable's aval (ignoring
            # weak types and named shapes).
            ct_aval = core.get_aval(ct_env[v])
            joined_aval = core.lattice_join(
                v.aval, ct_aval).strip_weak_type().strip_named_shape()
            assert v.aval.strip_weak_type().strip_named_shape(
            ) == joined_aval, (prim, v.aval, ct_aval)

    # Reading removes the entry from `ct_env`; a missing entry yields a Zero.
    def read_cotangent(v):
        return ct_env.pop(v, Zero(v.aval))

    # Literals carry their own value; unknown variables become
    # UndefinedPrimal.
    def read_primal(v):
        if type(v) is Literal:
            return v.val
        else:
            return primal_env.get(v, UndefinedPrimal(v.aval))

    # Only defined primal values are recorded.
    def write_primal(v, val):
        if not is_undefined_primal(val):
            primal_env[v] = val

    primal_env: Dict[Any, Any] = {}
    write_primal(core.unitvar, core.unit)
    map(write_primal, jaxpr.constvars, consts)
    # FIXME: invars can contain both primal and tangent values, and this line
    #        forces primal_in to contain UndefinedPrimals for tangent values!
    map(write_primal, jaxpr.invars, primals_in)

    ct_env: Dict[Any, Any] = {}
    # Seed the environment with the cotangents of the jaxpr's outputs.
    map(partial(write_cotangent, 'outvars'), jaxpr.outvars, cotangents_in)
    for eqn in jaxpr.eqns[::-1]:
        # FIXME: Some invars correspond to tangents
        invals = map(read_primal, eqn.invars)
        # Single-result primitives get a bare cotangent rather than a
        # sequence.
        if eqn.primitive.multiple_results:
            cts_in = map(read_cotangent, eqn.outvars)
        else:
            cts_in, = map(read_cotangent, eqn.outvars)
        with source_info_util.user_context(eqn.source_info):
            if eqn.primitive.call_primitive or eqn.primitive.map_primitive:
                # Call/map transposes receive the inner jaxpr and the
                # remaining params as separate arguments.
                cts_in_avals = [v.aval for v in eqn.outvars]
                call_jaxpr, params = core.extract_call_jaxpr(
                    eqn.primitive, eqn.params)
                cts_out = get_primitive_transpose(eqn.primitive)(params,
                                                                 call_jaxpr,
                                                                 invals,
                                                                 cts_in,
                                                                 cts_in_avals)
            else:
                cts_out = get_primitive_transpose(eqn.primitive)(cts_in,
                                                                 *invals,
                                                                 **eqn.params)
        # A rule may return the `Zero` class itself; expand that to one Zero
        # per input variable.
        cts_out = [Zero(v.aval)
                   for v in eqn.invars] if cts_out is Zero else cts_out
        # FIXME: Some invars correspond to primals!
        map(partial(write_cotangent, eqn.primitive), eqn.invars, cts_out)

    cotangents_out = map(read_cotangent, jaxpr.invars)
    return cotangents_out