Exemplo n.º 1
0
    def test_dropvar_avals(self):
        """The unused scan carry output keeps the traced input's aval."""
        def f(x):
            # Scan body that passes the carry through and emits no per-step ys.
            def body(carry, _):
                return carry, None

            # Only the second carry component is returned; the first is left
            # unused (presumably traced as a DropVar -- this test only checks
            # its aval).
            final_carry, _ = jax.lax.scan(body, (x, x), None, length=1)
            return [final_carry[1]]

        int32_scalar = core.ShapedArray((), jnp.dtype('int32'))
        traced, _, _ = pe.trace_to_jaxpr_nounits(
            lu.wrap_init(f), [pe.PartialVal.unknown(int32_scalar)], False)
        # Unpacking enforces that the first equation (expected to be the
        # scan) has exactly two outvars, one per carry component.
        unused_out, _ = traced.eqns[0].outvars
        self.assertEqual(unused_out.aval, int32_scalar)
Exemplo n.º 2
0
 def transposed(*args):
   """Run the backward pass of the closed-over ``jaxpr`` on flat arguments.

   Args:
     *args: flat leaves that ``treedef`` unflattens into the pair
       ``(in_primals, out_cts)``.

   Returns:
     The flattened leaves of a list that holds one cotangent per element of
     ``in_primals``. Undefined primals get the cotangent from
     ``ad.backward_pass``; defined primals get ``ad_util.Zero(x.aval)``.
     As a side effect, the tree structure of that list is stored in
     ``cell.treedef``.
   """
   primals_in, cts_out = tree_unflatten(treedef, args)

   # Defined primals are traced as known values; undefined ones as unknowns.
   pvals = []
   for p in primals_in:
     if ad.is_undefined_primal(p):
       pvals.append(pe.PartialVal.unknown(p.aval))
     else:
       pvals.append(pe.PartialVal.known(p))

   eval_fn = lu.wrap_init(partial(core.eval_jaxpr, jaxpr, ()))
   traced_jaxpr, _, traced_consts = pe.trace_to_jaxpr_nounits(
       eval_fn, pvals, False)
   placeholders = [ad.UndefinedPrimal(var.aval) for var in traced_jaxpr.invars]
   cts_in = ad.backward_pass(traced_jaxpr, reduce_axes, False, traced_consts,
                             placeholders, cts_out)

   # Re-expand to one cotangent per original input, consuming the backward
   # pass results in order for undefined primals and using Zero otherwise.
   cts_iter = iter(cts_in)
   cts_all = []
   for p in primals_in:
     if ad.is_undefined_primal(p):
       cts_all.append(next(cts_iter))
     else:
       cts_all.append(ad_util.Zero(p.aval))
   # Every cotangent produced by the backward pass must have been consumed.
   assert next(cts_iter, None) is None

   flat_cts, cell.treedef = tree_flatten(cts_all)
   return flat_cts
Exemplo n.º 3
0
def linearize(traceable, *primals, **kwargs):
  """Trace the JVP of ``traceable`` with known primals and unknown tangents.

  Args:
    traceable: function to linearize; it is wrapped with ``jvp``.
    *primals: primal input values. Each is passed once as a known value and
      once, as ``get_aval(p).at_least_vspace()``, as an unknown tangent.
    **kwargs: only ``has_aux`` (default ``False``) is read; any other keys
      are ignored.

  Returns:
    ``(out_primals, out_tangent_pvals, jaxpr, consts)``, where
    ``out_primals`` are the concrete known primal outputs. When ``has_aux``
    is truthy, ``aux()`` is appended as a fifth element.
  """
  has_aux = kwargs.pop('has_aux', False)
  if has_aux:
    jvpfun, aux = jvp(traceable, has_aux=True)
  else:
    jvpfun = jvp(traceable)

  # Primal inputs enter as known values; tangent inputs as unknown avals.
  known_pvals = [pe.PartialVal.known(p) for p in primals]
  tangent_pvals = [pe.PartialVal.unknown(get_aval(p).at_least_vspace())
                   for p in primals]
  in_pvals = tuple(known_pvals + tangent_pvals)

  # Input tree: positional args (primals, tangents) with matching structure,
  # and no keyword arguments.
  in_tree = tree_flatten(((primals, primals), {}))[1]
  jvpfun_flat, out_tree_thunk = flatten_fun(jvpfun, in_tree)
  jaxpr, out_pvals, consts = pe.trace_to_jaxpr_nounits(jvpfun_flat, in_pvals)

  # The output tree is only available after tracing has run.
  primal_pvals, tangent_out_pvals = tree_unflatten(out_tree_thunk(), out_pvals)
  assert all(pv.is_known() for pv in primal_pvals)
  primal_outs = [pv.get_known() for pv in primal_pvals]

  result = (primal_outs, tangent_out_pvals, jaxpr, consts)
  if has_aux:
    return result + (aux(),)
  return result