コード例 #1
0
ファイル: ad.py プロジェクト: jbampton/jax
def linear_jvp(primitive, primals, tangents, **params):
  """JVP rule that applies ``primitive`` itself to the tangents.

  The primal output is ``primitive.bind(*primals, **params)``. When at least
  one input tangent is not a symbolic ``Zero``, the output tangent is the same
  primitive bound on the tangents, with symbolic zeros materialized via
  ``instantiate_zeros``. When every input tangent is a symbolic ``Zero`` (or
  there are no tangents), the output tangent is ``Zero.from_value(val_out)``.

  Applying the primitive to its tangents is only a valid JVP for primitives
  that are linear in their inputs; presumably this rule is only registered
  for such primitives.
  """
  val_out = primitive.bind(*primals, **params)
  if any(type(t) is not Zero for t in tangents):
    # Binding needs concrete values, so turn symbolic zeros into real zeros.
    dense_tangents = map(instantiate_zeros, tangents)
    return val_out, primitive.bind(*dense_tangents, **params)
  # All tangents are symbolic zeros: skip the tangent computation entirely.
  return val_out, Zero.from_value(val_out)
コード例 #2
0
ファイル: ad.py プロジェクト: John1Tang/jax
def standard_jvp2(jvprules, primitive, primals, tangents, **params):
    """JVP rule assembled from per-argument rules that also see the primal output.

    ``jvprules[i]``, when not ``None``, is called as
    ``rule(t, val_out, *primals, **params)`` with the tangent ``t`` of argument
    ``i`` and returns that argument's contribution to the output tangent.
    Arguments whose rule is ``None`` or whose tangent is a symbolic ``Zero``
    contribute nothing. The contributions are summed with ``add_tangents``,
    starting from ``Zero.from_value(val_out)``.

    Returns:
      ``(val_out, tangent_out)``.
    """
    val_out = primitive.bind(*primals, **params)
    # Materialize every contribution first, so that all rule calls happen
    # before any add_tangents call.
    contributions = [
        rule(t, val_out, *primals, **params)
        for rule, t in zip(jvprules, tangents)
        if rule is not None and type(t) is not Zero
    ]
    total = Zero.from_value(val_out)
    for contribution in contributions:
        total = add_tangents(total, contribution)
    return val_out, total
コード例 #3
0
ファイル: ad.py プロジェクト: jbampton/jax
def f_jvp_traceable(nonzeros, *primals_and_nztangents):
  """Two-step generator that adapts a JVP to a flat, zero-free convention.

  ``primals_and_nztangents`` holds every primal first, followed by only those
  tangents whose flag in ``nonzeros`` is true. Each tangent that was left out
  is rebuilt as ``Zero.from_value(primal)``.

  Protocol:
    1. Yields ``((primals, tangents), {})`` and receives
       ``(primals_out, tangents_out)``.
    2. Yields ``(flat_outputs, out_nonzeros)``. ``flat_outputs`` is the output
       primals followed by the output tangents that are not symbolic ``Zero``.
       ``out_nonzeros`` flags, per output tangent, whether it was non-zero.
  """
  num_primals = len(nonzeros)
  primals = list(primals_and_nztangents[:num_primals])
  supplied_tangents = iter(primals_and_nztangents[num_primals:])
  tangents = []
  for primal, is_nonzero in zip(primals, nonzeros):
    if is_nonzero:
      tangents.append(next(supplied_tangents))
    else:
      tangents.append(Zero.from_value(primal))
  primals_out, tangents_out = yield (primals, tangents), {}
  out_nonzeros = [type(t) is not Zero for t in tangents_out]
  nonzero_tangents_out = [t for t in tangents_out if type(t) is not Zero]
  yield [*primals_out, *nonzero_tangents_out], out_nonzeros
コード例 #4
0
def _remat_transpose(primal_jaxpr, tangent_jaxpr, reduce_axes,
                     primals_tangents_in, cotangents_in):
  """Recompute residuals from known primals, then transpose ``tangent_jaxpr``.

  Judging by its name, this serves remat (rematerialization). In
  ``primals_tangents_in``, entries for which ``is_undefined_primal`` is true
  are the inputs being transposed; the others are known primal values. The
  known primals are run through ``primal_jaxpr`` to produce residuals.
  ``backward_pass`` is then run on ``tangent_jaxpr.jaxpr`` with inputs
  ``(*residuals, *undefined_inputs)``.

  Returns:
    A list with one entry per element of ``primals_tangents_in``: the computed
    cotangent for each undefined-primal entry, and ``Zero.from_value(x)`` for
    each known primal ``x``.
  """
  known_primals = []
  undefined_inputs = []
  for x in primals_tangents_in:
    if is_undefined_primal(x):
      undefined_inputs.append(x)
    else:
      known_primals.append(x)
  res = core.jaxpr_as_fun(primal_jaxpr)(*known_primals)
  all_cotangents = backward_pass(tangent_jaxpr.jaxpr, reduce_axes, False, (),
                                 (*res, *undefined_inputs), cotangents_in)
  # The outputs line up with the inputs (*res, *undefined_inputs); drop the
  # part corresponding to the residuals.
  input_cotangents = iter(all_cotangents[len(res):])
  outs = []
  for x in primals_tangents_in:
    if is_undefined_primal(x):
      outs.append(next(input_cotangents))
    else:
      outs.append(Zero.from_value(x))
  # Sanity check: every remaining cotangent was consumed.
  assert next(input_cotangents, None) is None
  return outs
コード例 #5
0
ファイル: ad.py プロジェクト: jbampton/jax
def jvpfun(instantiate, transform_stack, primals, tangents):
  """Two-step generator that traces a function under a fresh ``JVPTrace``.

  Tangents that are not already ``Zero`` and whose ``dtype`` is ``float0`` are
  replaced with ``Zero.from_value(t)``. When ``transform_stack`` is truthy,
  tracing runs inside ``source_info_util.transform_name_stack('jvp')``;
  otherwise inside ``contextlib.nullcontext()``.

  Protocol:
    1. Yields ``((main, primals, tangents), {})`` and receives
       ``(out_primals, out_tangents)``.
    2. Yields ``(out_primals, out_tangents)``, where symbolic zeros are
       instantiated for each output selected by ``instantiate``.
       ``instantiate`` is either a single bool applied to all outputs or a
       per-output sequence of flags.
  """
  def _normalize(t):
    # Keep the original short-circuit order: dtype() is only consulted for
    # tangents that are not already a symbolic Zero.
    if isinstance(t, Zero) or dtype(t) is not float0:
      return t
    return Zero.from_value(t)

  tangents = [_normalize(t) for t in tangents]
  if transform_stack:
    ctx = source_info_util.transform_name_stack('jvp')
  else:
    ctx = contextlib.nullcontext()
  with core.new_main(JVPTrace) as main, ctx:
    out_primals, out_tangents = yield (main, primals, tangents), {}
    # Drop the local reference to `main` before its context exits.
    del main
  if type(instantiate) is bool:
    instantiate = [instantiate] * len(out_tangents)
  out_tangents = [t if not inst else instantiate_zeros(t)
                  for t, inst in zip(out_tangents, instantiate)]
  yield out_primals, out_tangents
コード例 #6
0
def jvpfun(instantiate, primals, tangents):
    """Two-step generator that traces a function under a fresh ``JVPTrace``.

    Tangents that are not already ``Zero`` and whose ``dtype`` is ``float0``
    are replaced with ``Zero.from_value(t)`` before tracing.

    Protocol:
      1. Yields ``((main, primals, tangents), {})`` and receives
         ``(out_primals, out_tangents)``.
      2. Yields ``(out_primals, out_tangents)`` with symbolic zeros
         instantiated for each output selected by ``instantiate`` (a single
         bool for all outputs, or a per-output sequence of flags).
    """
    normalized = []
    for t in tangents:
        if not isinstance(t, Zero) and dtype(t) is float0:
            normalized.append(Zero.from_value(t))
        else:
            normalized.append(t)
    with core.new_main(JVPTrace) as main:
        out_primals, out_tangents = yield (main, primals, normalized), {}
        # Drop the local reference to `main` before its context exits.
        del main
    if type(instantiate) is bool:
        flags = [instantiate] * len(out_tangents)
    else:
        flags = instantiate
    out_tangents = [
        instantiate_zeros(t) if flag else t
        for t, flag in zip(out_tangents, flags)
    ]
    yield out_primals, out_tangents
コード例 #7
0
ファイル: ad.py プロジェクト: jbampton/jax
def zero_jvp(primitive, primals, tangents, **params):
  """JVP rule whose output tangent is always a symbolic zero.

  Computes the primal output ``primitive.bind(*primals, **params)`` and pairs
  it with ``Zero.from_value`` of that output. ``tangents`` is not used.
  """
  primal_out = primitive.bind(*primals, **params)
  zero_tangent = Zero.from_value(primal_out)
  return primal_out, zero_tangent
コード例 #8
0
ファイル: ad.py プロジェクト: jbampton/jax
    return Zero if out is Zero else (out, None)
  else:
    out = rhs_rule(cotangent, x, **kwargs)
    return Zero if out is Zero else (None, out)


def defjvp_zero(primitive):
  """Register ``zero_jvp`` (bound to ``primitive``) as its JVP rule.

  Stores ``partial(zero_jvp, primitive)`` in ``primitive_jvps[primitive]``.
  Asserts that ``primitive`` is a ``Primitive`` instance.
  """
  assert isinstance(primitive, Primitive)
  rule = partial(zero_jvp, primitive)
  primitive_jvps[primitive] = rule

def zero_jvp(primitive, primals, tangents, **params):
  """Evaluate ``primitive`` on ``primals`` and pair the result with a zero tangent.

  The tangent output is ``Zero.from_value(result)``. ``tangents`` is not
  inspected.
  """
  result = primitive.bind(*primals, **params)
  tangent_out = Zero.from_value(result)
  return result, tangent_out


# Register zeros_like_p via deflinear2. Its transpose callback maps the
# incoming cotangent `t` to a one-element list holding Zero.from_value(t);
# the second argument is ignored.
deflinear2(zeros_like_p, lambda t, _: [Zero.from_value(t)])
# Register add_jaxvals_p via deflinear2. Its transpose callback passes the
# incoming cotangent `t` unchanged to both operands; the remaining positional
# arguments are ignored.
deflinear2(add_jaxvals_p, lambda t, *args: (t, t))

def instantiate_zeros(tangent):
  """Replace a symbolic ``Zero`` tangent with a concrete value.

  If ``type(tangent) is Zero``, returns ``zeros_like_aval(tangent.aval)``.
  Any other tangent is returned unchanged.
  """
  if type(tangent) is not Zero:
    return tangent
  return zeros_like_aval(tangent.aval)

# This function seems similar to instantiate_zeros, but it is sometimes used
# to instantiate zero abstract units with a different aval
def instantiate_zeros_aval(aval, tangent):
  if type(tangent) is Zero:
    assert type(tangent.aval) is core.AbstractUnit or tangent.aval == aval
    return zeros_like_aval(aval)
  else: