Ejemplo n.º 1
0
def unzip_eval_wrapper(pvs, *consts):
  """Generator transformation that flattens unzip results and their metadata.

  The first yield hands out a one-element args tuple containing the list of
  `pe.PartialVal`s built from the `(pv, const)` pairs, plus empty kwargs. The
  generator is then sent a `(success, result)` pair and yields a flat tuple of
  constants together with `(success, len(flat), aux)`, where the layout of
  `aux` depends on `success`.
  """
  partial_vals = safe_map(pe.PartialVal, safe_zip(pvs, consts))
  success, result = yield (partial_vals,), {}

  if not success:
    # Unsuccessful case: `result` carries a single jaxpr and its outputs.
    jaxpr, (out_pvals, out_keys, result_consts, env) = result
    out_pvs, out_consts = jax_util.unzip2(out_pvals)
    flat = (*out_consts, *result_consts)
    yield flat, (success, len(flat), (out_pvs, out_keys, jaxpr, env))
    return

  # Successful case: `result` carries separate init and apply pieces.
  init_out, apply_out, pvals, metadata = result
  init_jaxpr, init_consts, init_env = init_out
  apply_jaxpr, apply_consts, apply_env = apply_out
  init_pvals, apply_pvals = pvals
  init_pvs, init_pv_consts = jax_util.unzip2(init_pvals)
  apply_pvs, apply_pv_consts = jax_util.unzip2(apply_pvals)

  flat = (*init_pv_consts, *init_consts, *apply_pv_consts, *apply_consts)
  aux = (
      (init_pvs, len(init_consts), apply_pvs),
      (init_jaxpr, apply_jaxpr),
      (init_env, apply_env),
      metadata,
  )
  yield flat, (success, len(flat), aux)
Ejemplo n.º 2
0
    def trace_to_jaxpr_finalize(in_tracers,
                                out_tracers,
                                trace,
                                instantiate=True):
        """Turn traced inputs/outputs into a `core.TypedJaxpr` plus constants.

        Args:
          in_tracers: tracers standing for the inputs of the computation.
          out_tracers: values produced by the traced computation.
          trace: trace used to raise the outputs and instantiate constants.
          instantiate: flag repeated once per output and passed to
            `pe.instantiate_const_at`.

        Returns:
          A `(typed_jaxpr, consts)` pair. The typed jaxpr's input avals are the
          constant avals followed by the input avals.
        """
        # TODO: This is the final part of the partial_eval.trace_to_subjaxpr. Share.
        per_output_flags = [instantiate] * len(out_tracers)
        lowered = safe_map(core.full_lower, out_tracers)
        raised = safe_map(trace.full_raise, lowered)
        instantiate_at = partial(pe.instantiate_const_at, trace)
        final_tracers = safe_map(instantiate_at, per_output_flags, raised)
        jaxpr, consts, env = pe.tracers_to_jaxpr(in_tracers, final_tracers)
        out_pvals = [tracer.pval for tracer in final_tracers]
        # TODO: this is from partial_eval.trace_to_jaxpr. Share.
        assert not env

        # TODO: this is from the final part of lax_control_flow._initial_style_jaxpr
        to_shaped = abstract_arrays.raise_to_shaped
        out_avals = safe_map(to_shaped, unzip2(out_pvals)[0])
        const_avals = tuple(to_shaped(core.get_aval(const)) for const in consts)

        in_pvals = [tracer.pval for tracer in in_tracers]
        in_avals = tuple(safe_map(to_shaped, unzip2(in_pvals)[0]))

        converted = pe.convert_constvars_jaxpr(jaxpr)
        typed_jaxpr = core.TypedJaxpr(converted, (), const_avals + in_avals,
                                      out_avals)
        return typed_jaxpr, consts
Ejemplo n.º 3
0
def dot_general_dependency_rule(outstart, outcount, lhs, rhs,
                                dimension_numbers, precision):
    """Dependency rule for `lax.dot_general`.

    Only an all-ones `outcount` is supported; anything else raises
    `NotImplementedError`.

    Returns:
      A triple of (input boxes for `lhs` and `rhs`, per-operand input counts
      with `None` for operands that are not `LazyArray`s, a function applying
      `lax.dot_general` with the given `dimension_numbers` and `precision`
      to input slices).
    """
    if not is_ones(outcount):
        raise NotImplementedError
    outshape = outcount.shape
    outslices = list(zip(outstart, outshape))
    contracting, batch = dimension_numbers
    lhs_contracting, rhs_contracting = contracting
    lhs_batch, rhs_batch = batch

    # Output dimensions that come from the non-batch, non-contracting
    # dimensions of each operand.
    lhs_other_out_dims = list(
        range(len(lhs_batch), len(lhs.shape) - len(lhs_contracting)))
    rhs_other_out_dims = list(
        range(len(rhs_batch) + len(lhs_other_out_dims), len(outshape)))

    def _operand_dependency(operand, batch_dims, other_dims, axes):
        # Apply the reduce dependency rule to the output slices that this
        # operand contributes to.
        dims = list(batch_dims) + other_dims
        start, shape = unzip2([outslices[d] for d in dims])
        (box, ), (count, ), _ = reduce_dependency_rule(None)(
            start, Ones(shape), operand, axes=axes)
        return box, count

    lhs_box, lhs_count = _operand_dependency(lhs, lhs_batch,
                                             lhs_other_out_dims,
                                             lhs_contracting)
    rhs_box, rhs_count = _operand_dependency(rhs, rhs_batch,
                                             rhs_other_out_dims,
                                             rhs_contracting)

    lhs_incount = None
    if isinstance(lhs, LazyArray):
        lhs_incount = materialize(lhs_count) * prod(
            np.take(outshape, rhs_other_out_dims))
    rhs_incount = None
    if isinstance(rhs, LazyArray):
        rhs_incount = materialize(rhs_count) * prod(
            np.take(outshape, lhs_other_out_dims))

    return ([lhs_box, rhs_box], [lhs_incount, rhs_incount],
            lambda *inslices: lax.dot_general(
                *inslices, dimension_numbers, precision))
Ejemplo n.º 4
0
 def split_tracers_and_nontracers(jaxpr, consts):
     """Partition the jaxpr's constants into literal and non-literal groups.

     A constant counts as non-literal when it is a `core.Tracer` or an
     `xla.DeviceArray`.

     Returns:
       `(vars, literal_consts, non_literal_consts)` where `vars` lists the
       constvars of the literal constants first, then the rest.
     """
     non_literal_pairs, literal_pairs = [], []
     for var, const in zip(jaxpr.constvars, consts):
         # TODO(phawkins): We avoid treating DeviceArrays as constant literals so
         # we don't copy large arrays back to the host. We probably should relax
         # this and either always copy small constants, or opportunistically use
         # DeviceArray values for which we already know npy_value.
         if isinstance(const, (core.Tracer, xla.DeviceArray)):
             non_literal_pairs.append((var, const))
         else:
             literal_pairs.append((var, const))
     tracer_vars, tracer_consts = unzip2(non_literal_pairs)
     nontracer_vars, nontracer_consts = unzip2(literal_pairs)
     return nontracer_vars + tracer_vars, nontracer_consts, tracer_consts
Ejemplo n.º 5
0
Archivo: jet.py Proyecto: sts-sadr/jax
 def process_call(self, call_primitive, f, tracers, params):
   """Bind `call_primitive` on a jet-transformed version of `f`.

   The primals and series of the input tracers are flattened into a single
   argument list; the flat result is unflattened back into output primals and
   series, which are wrapped in `JetTracer`s.
   """
   primals_in, series_in = unzip2((t.primal, t.terms) for t in tracers)
   flat_inputs, in_tree_def = tree_flatten((primals_in, series_in))
   jet_fun = jet_subtrace(f, self.master)
   f_jet, out_tree_def = traceable(jet_fun, in_tree_def)
   flat_outputs = call_primitive.bind(f_jet, *flat_inputs, **params)
   primals_out, series_out = tree_unflatten(out_tree_def(), flat_outputs)
   out_tracers = []
   for primal, terms in zip(primals_out, series_out):
     out_tracers.append(JetTracer(self, primal, terms))
   return out_tracers
Ejemplo n.º 6
0
 def wrapped(*args, **kwargs):
     """Invert the traced function at `args` and return values with ILDJ.

     The function is staged on `trace_args` when those are non-empty and on
     `args` otherwise. Raises `ValueError` if any input cell is not resolved
     after propagation.
     """
     forward_args = trace_args if len(trace_args) else args
     staged = trace_util.stage(f, dynamic=False)
     jaxpr, (in_tree, _) = staged(*forward_args, **kwargs)
     flat_forward_args = tree_util.tree_flatten(forward_args)[0]
     flat_args = tree_util.tree_flatten(args)[0]

     # Constants are known, inputs are unknown, and the outputs are seeded
     # with the values passed to this wrapper.
     const_cells = safe_map(InverseAndILDJ.new, jaxpr.literals)
     forward_avals = [
         trace_util.get_shaped_aval(arg) for arg in flat_forward_args
     ]
     unknown_cells = [InverseAndILDJ.unknown(aval) for aval in forward_avals]
     out_cells = safe_map(InverseAndILDJ.new, flat_args)
     env = propagate.propagate(InverseAndILDJ, ildj_registry, jaxpr.jaxpr,
                               const_cells, unknown_cells, out_cells)

     solved_cells = [env.read(invar) for invar in jaxpr.jaxpr.invars]
     if not all(cell.top() for cell in solved_cells):
         raise ValueError('Cannot invert function.')

     flat_vals, flat_ildjs = jax_util.unzip2([
         (cell.val, cell.ildj) for cell in solved_cells
     ])
     vals = tree_util.tree_unflatten(in_tree, flat_vals)
     if reduce_ildj:
         ildj_ = sum(np.sum(i) for i in flat_ildjs)
     else:
         ildj_ = tree_util.tree_unflatten(in_tree, flat_ildjs)

     # A single forward argument is returned bare rather than as a 1-tuple.
     if len(forward_args) == 1:
         vals = vals[0]
         if not reduce_ildj:
             ildj_ = ildj_[0]
     return vals, ildj_
Ejemplo n.º 7
0
    def testScanVmapTuples(self):
        """vmap of `lax.scan` over tuple carries/inputs matches a manual loop."""

        def f(carry, inputs):
            a1, a2 = inputs
            c1, c2 = carry
            out = np.sum(np.cos(a1)) * np.sum(np.tan(c2 * a2))
            new_carry = (c1 * np.sin(np.sum(a1 * a2)),
                         c2 * np.cos(np.sum(a1)))
            return new_carry, out

        in_axes = (0, (1, 2))

        rng = onp.random.RandomState(0)
        as_ = (rng.randn(3, 7), rng.randn(3, 4, 7))
        c = (rng.randn(7, 2), rng.randn(7))

        # Reference result: run the scan once per batch element and stack.
        per_element = [
            lax.scan(f, (c[0][i], c[1][i]), (as_[0][:, i], as_[1][:, :, i]))
            for i in range(7)
        ]
        carries = [carry for carry, _ in per_element]
        outputs = [out for _, out in per_element]
        carries_0, carries_1 = unzip2(carries)
        expected = ((np.stack(carries_0), np.stack(carries_1)),
                    np.stack(outputs))

        ans = api.vmap(lambda c, as_: lax.scan(f, c, as_), in_axes)(c, as_)
        self.assertAllClose(ans, expected, check_dtypes=False)
Ejemplo n.º 8
0
def custom_layer_cau_batch(trace, f, tracers, params):
    """Batching rule for layer_cau primitive to handle custom layers.

    When none of the inputs is batched the primitive is bound directly on the
    raw values. When the layer (the first unflattened argument) exposes
    `_call_and_update_batched` and its arguments are batched while the layer's
    own leaves are not, `_layer_cau_batched` is used. Every other case falls
    back to `batching.batch_subtrace`.

    Args:
      trace: trace used to build the output `batching.BatchTracer`s.
      f: wrapped function being bound by the primitive.
      tracers: input tracers carrying `val` and `batch_dim`.
      params: primitive parameters; `params['in_tree']` and
        `params['kwargs']` (including `'has_rng'`) are read here.

    Returns:
      A list of `batching.BatchTracer`s.
    """
    # Split the tracers into underlying values and their batch dimensions.
    vals, dims = jax_util.unzip2((t.val, t.batch_dim) for t in tracers)
    # Nothing is batched: no batching work to do.
    if all(dim is batching.not_mapped for dim in dims):
        return layer_cau_p.bind(f, *vals, **params)
    args = tree_util.tree_unflatten(params['in_tree'], vals)
    # Normalize `None` batch dims to the `not_mapped` sentinel.
    dims_ = [not_mapped if dim is None else dim for dim in dims]
    # The first unflattened argument is the layer; the rest are its inputs.
    layer, args = args[0], args[1:]
    if hasattr(layer, '_call_and_update_batched'):
        # The leading `num_params` dims correspond to the layer's leaves.
        num_params = len(tree_util.tree_leaves(layer))
        layer_dims, arg_dims = dims_[:num_params], dims_[num_params:]
        if params['kwargs']['has_rng']:
            # Peel the rng (and its batch dim) off the front of the arguments.
            rng, args = args[0], args[1:]
            rng_dim, arg_dims = arg_dims[0], arg_dims[1:]
        mapping_over_layer = all(layer_dim is not not_mapped
                                 for layer_dim in layer_dims)
        mapping_over_args = all(arg_dim is not not_mapped
                                for arg_dim in arg_dims)
        assert mapping_over_layer or mapping_over_args, (layer_dims, arg_dims)
        # Custom path: the arguments are batched but the layer is not.
        if not mapping_over_layer and mapping_over_args:
            if params['kwargs']['has_rng']:
                if rng_dim is not not_mapped:
                    # The rng itself is batched: vmap over it only, leaving the
                    # layer and the remaining arguments unmapped.
                    arg_dims = tuple(None if dim is not_mapped else dim
                                     for dim in arg_dims)
                    map_fun = jax.vmap(
                        lambda layer, rng, *args: _layer_cau_batched(
                            layer,
                            rng,
                            *args,  # pylint: disable=unnecessary-lambda, g-long-lambda
                            **params['kwargs']),
                        in_axes=(None, rng_dim) + (None, ) * len(arg_dims))
                else:
                    # Unbatched rng: it is passed through as the first of
                    # `*args` in the call below.
                    map_fun = lambda layer, *args: _layer_cau_batched(
                        layer,
                        *args,  # pylint: disable=unnecessary-lambda, g-long-lambda
                        **params['kwargs'])
                vals_out, update_out = map_fun(layer, rng, *args)
            else:
                vals_out, update_out = _layer_cau_batched(
                    layer, *args, **params['kwargs'])
            vals_out = tree_util.tree_leaves(vals_out)
            update_out = tree_util.tree_leaves(update_out)
            assert all(dim == 0 for dim in arg_dims)
            # Assume dimensions out are consistent
            dims_out = (0, ) * len(vals_out)
            dims_update = (None, ) * len(update_out)
            # NOTE(review): after this line `dims_out` already includes
            # `dims_update`, so `dims_out + dims_update` in the `zip` below has
            # extra trailing entries that `zip` silently drops.
            dims_out = dims_out + dims_update

            # Call wrapped function to avoid linear_util error
            f.call_wrapped(*tracers)
            return [
                batching.BatchTracer(trace, v, d)
                for v, d in zip(vals_out + update_out, dims_out + dims_update)
            ]
    # Generic fallback: batch the wrapped function and bind the batched subcall.
    f, dims_out = batching.batch_subtrace(f, trace.master, dims)
    vals_out = layer_cau_p.subcall('batch').bind(f, *vals, **params)
    return [
        batching.BatchTracer(trace, v, d)
        for v, d in zip(vals_out, dims_out())
    ]
Ejemplo n.º 9
0
Archivo: jet.py Proyecto: nhanwei/jax
def jet_subtrace(main, primals, series):
    """Generator transformation running a function under a `JetTrace`.

    Yields the input `JetTracer`s (with empty kwargs), receives the function's
    outputs, and finally yields the output primals and their series terms.
    """
    trace = JetTrace(main, core.cur_sublevel())
    make_tracer = partial(JetTracer, trace)
    ans = yield map(make_tracer, primals, series), {}
    raised = map(trace.full_raise, ans)
    primal_term_pairs = ((t.primal, t.terms) for t in raised)
    out_primals, out_terms = unzip2(primal_term_pairs)
    yield out_primals, out_terms
Ejemplo n.º 10
0
def _make_typed_jaxpr(traceable, in_avals):
    """Trace `traceable` at `in_avals` and wrap the result in a `TypedJaxpr`.

    All inputs are given as partial values pairing the aval with `core.unit`,
    and outputs are traced with `instantiate=True`.
    """
    in_pvals = []
    for aval in in_avals:
        in_pvals.append(pe.PartialVal((aval, core.unit)))
    jaxpr, out_pvals, consts = pe.trace_to_jaxpr(
        traceable, in_pvals, instantiate=True)
    out_avals = unzip2(out_pvals)[0]
    return core.TypedJaxpr(jaxpr, consts, in_avals, out_avals)
Ejemplo n.º 11
0
Archivo: jet.py Proyecto: romanodev/jax
def jet_transform(primals, series):
  """Generator transformation running a function under a fresh `JetTrace`.

  A new master trace is opened for the duration of the call. The output
  primals and series terms are yielded after that master has been closed.
  """
  with core.new_master(JetTrace) as master:
    trace = JetTrace(master, core.cur_sublevel())
    make_tracer = partial(JetTracer, trace)
    ans = yield map(make_tracer, primals, series), {}
    raised = map(trace.full_raise, ans)
    primal_term_pairs = ((t.primal, t.terms) for t in raised)
    out_primals, out_terms = unzip2(primal_term_pairs)
  yield out_primals, out_terms
Ejemplo n.º 12
0
def doubling_subtrace(main, heads, tails):
  """Generator transformation running a function under a `DoublingTrace`.

  Inputs whose tail is `None` are passed through as the bare head; all others
  are wrapped in a `DoublingTracer`. The final yield is the pair of output
  heads and output tails.
  """
  trace = DoublingTrace(main, core.cur_sublevel())
  in_tracers = []
  for head, tail in zip(heads, tails):
    if tail is None:
      in_tracers.append(head)
    else:
      in_tracers.append(DoublingTracer(trace, head, tail))
  ans = yield in_tracers, {}
  raised = map(trace.full_raise, ans)
  head_tail_pairs = [(tracer.head, tracer.tail) for tracer in raised]
  yield unzip2(head_tail_pairs)
Ejemplo n.º 13
0
def _initial_style_jaxpr(fun, in_tree, in_avals):
  """Trace `fun` to a `TypedJaxpr` whose constants become leading inputs.

  Returns:
    `(typed_jaxpr, consts, out_tree)`, where the typed jaxpr's input avals are
    the shaped avals of `consts` followed by `in_avals`.
  """
  in_pvals = [pe.PartialVal((aval, core.unit)) for aval in in_avals]
  wrapped = lu.wrap_init(fun)
  flat_fun, out_tree = flatten_fun_nokwargs(wrapped, in_tree)
  jaxpr, out_pvals, consts = pe.trace_to_jaxpr(
      flat_fun, in_pvals, instantiate=True)
  out_pvs = unzip2(out_pvals)[0]
  out_avals = _map(raise_to_shaped, out_pvs)
  const_avals = tuple(raise_to_shaped(core.get_aval(const)) for const in consts)
  converted = pe.closure_convert_jaxpr(jaxpr)
  typed_jaxpr = core.TypedJaxpr(
      converted, (), const_avals + in_avals, out_avals)
  return typed_jaxpr, consts, out_tree()
Ejemplo n.º 14
0
 def new_f(*args, **kwargs):
     """Run `f` inside a mesh built from `named_shape`, or skip the test.

     Raises `SkipTest` when fewer local devices are available than the mesh
     needs.
     """
     axis_names, shape = unzip2(named_shape)
     size = np.prod(shape)
     available = list(jax.local_devices())
     if size > len(available):
         raise SkipTest(f"Test requires {size} local devices")
     device_grid = np.array(available[:size]).reshape(shape)
     with mesh(device_grid, axis_names):
         return f(*args, **kwargs)
Ejemplo n.º 15
0
def _scan_partial_eval(trace, *tracers, **kwargs):
    """Partial-evaluation rule for `scan_p`.

    Repeatedly calls `pe.partial_eval_jaxpr` on the scan body until the
    carry's marker stops changing, then binds `scan_p` with the first
    resulting jaxpr and records an equation that applies the second one.

    Args:
      trace: the trace the resulting tracer belongs to.
      *tracers: exactly three tracers, unpacked as (consts, init, xs).
      **kwargs: must contain exactly `jaxpr`, `length` and `forward`.

    Returns:
      A `pe.JaxprTracer` for the scan output.

    Raises:
      FixedPointError: if the carry does not stabilize within 1000 iterations.
    """
    jaxpr = kwargs.pop('jaxpr')
    length = kwargs.pop('length')
    forward = kwargs.pop('forward')
    # No other keyword parameters are accepted.
    assert not kwargs
    in_pvs, _ = unzip2([t.pval for t in tracers])
    # presumably the `sc_*` values mark which inputs are unknown — confirm
    # against `pe.unknown`.
    sc_consts, sc_init, sc_xs = map(pe.unknown, in_pvs)

    # Fixed-point iteration on the carry: stop once the carry coming out of the
    # body matches the carry going in, otherwise join the two and retry.
    sc_carry = sc_init
    for i in range(1000):
        second_components = (sc_consts, sc_carry, sc_xs)
        jaxpr_1, jaxpr_2, sc_out = pe.partial_eval_jaxpr(jaxpr,
                                                         second_components,
                                                         instantiate=(sc_carry,
                                                                      False))
        sc_carry_out, sc_ys = sc_out
        if sc_carry_out == sc_carry:
            break
        else:
            sc_carry = _binary_lattice_join(sc_carry, sc_carry_out)
    else:
        # The loop ran out of iterations without reaching a fixed point.
        raise FixedPointError

    # Lift the init tracer to match the stabilized carry before re-reading the
    # partial values.
    consts_tracer, init_tracer, xs_tracer = tracers
    lifted_init_tracer = _lift_tracer(trace, init_tracer, sc_carry)
    lifted_tracers = consts_tracer, lifted_init_tracer, xs_tracer
    in_pvs, in_consts = unzip2([t.pval for t in lifted_tracers])

    # Build the output abstract value: the carry plus the per-step outputs
    # promoted by `length`.
    carry_aval, y_aval = jaxpr.out_aval
    ys_aval = _promote_aval_rank(length, y_aval)
    out_aval = core.AbstractTuple((carry_aval, ys_aval))
    out_pv = _put_known_pvs(sc_out, out_aval)

    # Run the first jaxpr now; it also produces the residuals consumed by the
    # second jaxpr.
    out_carry, (ys, residuals) = scan_p.bind(*in_consts,
                                             forward=forward,
                                             length=length,
                                             jaxpr=jaxpr_1)
    out_const = core.pack((out_carry, ys))
    residuals_tracer = trace.new_instantiated_const(core.pack(residuals))
    d, c, a = lifted_tracers
    # The residuals are paired with the third tracer in the staged equation.
    new_tracers = (d, c, (a, residuals_tracer))
    eqn = core.JaxprEqn(new_tracers, None, scan_p, (), True, False,
                        dict(forward=forward, length=length, jaxpr=jaxpr_2))
    return pe.JaxprTracer(trace, pe.PartialVal((out_pv, out_const)), eqn)
Ejemplo n.º 16
0
Archivo: jet.py Proyecto: yangliuy/jax
 def post_process_call(self, call_primitive, out_tracers, params):
   """Flatten the output tracers and provide a callback to rebuild them.

   Returns:
     `(out, todo)`: the flattened primals and series, and a function that
     unflattens such a flat list and wraps it in `JetTracer`s at the current
     sublevel.
   """
   out_primals, out_series = unzip2((t.primal, t.terms) for t in out_tracers)
   out, treedef = tree_flatten((out_primals, out_series))
   # Drop the unflattened references; only the flat list is needed from here.
   del out_primals, out_series
   master = self.master

   def todo(x):
     primals, series = tree_unflatten(treedef, x)
     make_tracer = partial(JetTracer, JetTrace(master, core.cur_sublevel()))
     return map(make_tracer, primals, series)

   return out, todo
Ejemplo n.º 17
0
 def process_call(self, call_primitive, f, tracers, params):
     """Bind `call_primitive`, threading merged net params when present.

     Map primitives are not supported and raise `NotImplementedError`. If no
     tracer carries net params, the primitive is bound on the raw values and
     its result is returned unwrapped.
     """
     if call_primitive in pe.map_primitives:
         raise NotImplementedError
     vals, net_params = unzip2((t.val, t.net_params) for t in tracers)
     if not any(net_params):
         return call_primitive.bind(f, *vals, **params)
     merged = merge_params(net_params)
     apply_f = apply_subtrace(f, self.master, WrapHashably(merged))
     val_out = call_primitive.bind(apply_f, *vals, **params)
     return ApplyTracer(self, merged, val_out)
Ejemplo n.º 18
0
 def process_primitive(self, primitive, tracers, params):
     """Evaluate `primitive` on the tracers' values with merged net params.

     `Layer` primitives are applied through their `apply_fun` using the
     parameters stored under the layer's name; any other primitive is bound
     normally. The result is wrapped in an `ApplyTracer`.
     """
     vals_in, net_params = unzip2((t.val, t.net_params) for t in tracers)
     net_params = merge_params(net_params)
     if isinstance(primitive, Layer):
         result = primitive.apply_fun(net_params[primitive.name], *vals_in)
     else:
         result = primitive.bind(*vals_in, **params)
     return ApplyTracer(self, net_params, result)
Ejemplo n.º 19
0
 def tree_update(i, grad_tree, opt_state):
   """Apply the per-leaf `update` function across an optimizer state.

   Args:
     i: step index, forwarded unchanged to every `update` call.
     grad_tree: gradient pytree; flattened with `flatten`.
     opt_state: a `(states_flat, tree, subtrees)` triple.

   Returns:
     An `OptimizerState` built from the flattened new states, the original
     `tree`, and `unzip2(new_states)`.
   """
   states_flat, tree, subtrees = opt_state
   grad_flat = flatten(grad_tree)
   # if tree2 != tree:
   #   msg = ("optimizer update function was passed a gradient tree that did "
   #          "not match the parameter tree structure with which it was "
   #          "initialized: parameter tree {} and grad tree {}.")
   #   raise TypeError(msg.format(tree, tree2))
   states = map(pack_sequence_as, subtrees, states_flat)
   # Pair up corresponding entries of the first two states. The index gets its
   # own name so the step index `i` is not overwritten before it reaches
   # `update` below.
   states_ = [(states[0][j], states[1][j]) for j in range(len(states[0]))]
   new_states = map(partial(update, i), grad_flat, states_)
   new_states_flat = unzip2(map(flatten, new_states))
   # for subtree, subtree2 in zip(subtrees, subtrees2):
   #   if subtree2 != subtree:
   #     msg = ("optimizer update function produced an output structure that "
   #            "did not match its input structure: input {} and output {}.")
   #     raise TypeError(msg.format(subtree, subtree2))
   return OptimizerState(new_states_flat, tree, unzip2(new_states))
Ejemplo n.º 20
0
Archivo: jet.py Proyecto: romanodev/jax
 def process_primitive(self, primitive, tracers, params):
   """Apply the jet rule registered for `primitive` to the input tracers.

   The series order is taken from the non-zero input series, which must all
   share one length. Symbolic zero series and zero terms are replaced by
   explicit zero arrays shaped like the corresponding primal.
   """
   primals_in, series_in = unzip2((t.primal, t.terms) for t in tracers)
   (order,) = {len(terms) for terms in series_in if terms is not zero_series}
   dense_series = []
   for primal, terms in zip(primals_in, series_in):
     if terms is zero_series:
       terms = [zero_term] * order
     # TODO(mattjj): avoid always instantiating zeros
     dense_series.append([
         onp.zeros(onp.shape(primal), dtype=onp.result_type(primal))
         if term is zero_term else term
         for term in terms
     ])
   primal_out, terms_out = jet_rules[primitive](
       primals_in, dense_series, **params)
   return JetTracer(self, primal_out, terms_out)
Ejemplo n.º 21
0
 def process_call(self, call_primitive, f, tracers, params):
   """Bind `call_primitive` on a doubling-transformed version of `f`.

   The heads and the flattened tails are passed as arguments; the call's name
   is wrapped with 'doubledouble' and no input is marked as donated.
   """
   assert call_primitive.multiple_results
   heads, tails = unzip2((t.head, t.tail) for t in tracers)
   nonzero_tails, in_tree_def = tree_flatten(tails)
   subtraced = doubling_subtrace(f, self.main)
   f_double, out_tree_def = screen_nones(subtraced, len(heads), in_tree_def)
   name = params.get('name', f.__name__)
   num_inputs = len(heads) + len(nonzero_tails)
   new_params = dict(params)
   new_params['name'] = wrap_name(name, 'doubledouble')
   new_params['donated_invars'] = (False,) * num_inputs
   result = call_primitive.bind(f_double, *heads, *nonzero_tails, **new_params)
   heads_out, tails_out = tree_unflatten(out_tree_def(), result)
   out_tracers = []
   for head, tail in zip(heads_out, tails_out):
     out_tracers.append(DoublingTracer(self, head, tail))
   return out_tracers
Ejemplo n.º 22
0
def _ppermute_translation_rule(c, x, replica_groups, perm, platform=None):
  """Translate `ppermute` into a collective permute over all replica groups.

  Sources and destinations (taken modulo the size of the first group) must
  each be unique, otherwise `ValueError` is raised. The permutation is applied
  within each sorted replica group.
  """
  group_size = len(replica_groups[0])
  srcs, dsts = unzip2((src % group_size, dst % group_size) for src, dst in perm)
  if len(srcs) != len(set(srcs)) or len(dsts) != len(set(dsts)):
    msg = "ppermute sources and destinations must be unique, got {}."
    raise ValueError(msg.format(perm))

  full_perm = [
      (members[src], members[dst])
      for members in map(sorted, replica_groups)
      for src, dst in perm
  ]
  return xops.CollectivePermute(x, full_perm)
Ejemplo n.º 23
0
 def process_primitive(self, primitive, tracers, params):
   """Apply the jet rule registered for `primitive` to the input tracers.

   The series order is read from the master trace. Symbolic zero series and
   zero terms are replaced by explicit zero arrays shaped like the
   corresponding primal.
   """
   assert not primitive.multiple_results  # TODO
   order = self.master.order              # pytype: disable=attribute-error
   primals_in, series_in = unzip2((t.primal, t.terms) for t in tracers)
   dense_series = []
   for primal, terms in zip(primals_in, series_in):
     if terms is zero_series:
       terms = [zero_term] * order
     # TODO(mattjj): avoid always instantiating zeros
     dense_series.append([
         np.zeros(np.shape(primal), dtype=np.result_type(primal))
         if term is zero_term else term
         for term in terms
     ])
   primal_out, terms_out = jet_rules[primitive](
       primals_in, dense_series, **params)
   return JetTracer(self, primal_out, terms_out)
Ejemplo n.º 24
0
Archivo: jet.py Proyecto: yangliuy/jax
 def process_call(self, call_primitive, f, tracers, params):
   """Bind `call_primitive` on a jet-transformed version of `f`.

   If `params` contains `donated_invars`, any requested donation raises
   `ValueError`; otherwise the entry is replaced by one `False` per flattened
   input.
   """
   primals_in, series_in = unzip2((t.primal, t.terms) for t in tracers)
   flat_inputs, in_tree_def = tree_flatten((primals_in, series_in))
   f_jet, out_tree_def = traceable(jet_subtrace(f, self.master), in_tree_def)
   bind_params = dict(params)
   if "donated_invars" in params:
     if any(params["donated_invars"]):
       raise ValueError("Buffer donation is not supported with jet.")
     bind_params["donated_invars"] = (False,) * len(flat_inputs)
   flat_outputs = call_primitive.bind(f_jet, *flat_inputs, **bind_params)
   primals_out, series_out = tree_unflatten(out_tree_def(), flat_outputs)
   return [
       JetTracer(self, primal, terms)
       for primal, terms in zip(primals_out, series_out)
   ]
Ejemplo n.º 25
0
Archivo: jet.py Proyecto: nhanwei/jax
 def process_call(self, call_primitive, f, tracers, params):
     """Bind `call_primitive` on a jet-transformed version of `f`.

     If `call_param_updaters` has an entry for the primitive, it is used to
     rewrite `params` for the flattened number of inputs.
     """
     primals_in, series_in = unzip2((t.primal, t.terms) for t in tracers)
     flat_inputs, in_tree_def = tree_flatten((primals_in, series_in))
     jet_fun = jet_subtrace(f, self.main)
     f_jet, out_tree_def = traceable(jet_fun, in_tree_def)
     new_params = params
     updater = call_param_updaters.get(call_primitive)
     if updater:
         new_params = updater(params, len(flat_inputs))
     flat_outputs = call_primitive.bind(f_jet, *flat_inputs, **new_params)
     primals_out, series_out = tree_unflatten(out_tree_def(), flat_outputs)
     out_tracers = []
     for primal, terms in zip(primals_out, series_out):
         out_tracers.append(JetTracer(self, primal, terms))
     return out_tracers
Ejemplo n.º 26
0
def _ppermute_translation_rule(c, x, device_groups, perm):
    group_size = len(device_groups[0])
    if not all(0 <= i < group_size and 0 <= j < group_size for i, j in perm):
        msg = (
            "ppermute permutation elements must take on values between 0 and "
            "the group size {}, but got {}.")
        raise ValueError(msg.format(group_size, perm))
    sources, dests = unzip2(perm)
    if not (len(sources) == len(set(sources))
            and len(dests) == len(set(dests))):
        msg = "ppermute sources and destinations must be unique, got {}."
        raise ValueError(msg.format(perm))

    full_perm = []
    for grp in device_groups:
        grp = list(sorted(grp))
        full_perm.extend((grp[src], grp[dst]) for src, dst in perm)
    return c.CollectivePermute(x, full_perm)
Ejemplo n.º 27
0
 def tree_update(i, grad_tree, opt_state):
   """Apply the per-leaf `update` function across a packed optimizer state.

   Raises:
     TypeError: if the gradient tree structure differs from the stored
       parameter tree, or if `update` returns a state whose structure differs
       from its input state.
   """
   packed_state, tree, subtrees = opt_state
   grad_flat, grad_treedef = tree_flatten(grad_tree)
   if grad_treedef != tree:
     msg = ("optimizer update function was passed a gradient tree that did "
            "not match the parameter tree structure with which it was "
            "initialized: parameter tree {} and grad tree {}.")
     raise TypeError(msg.format(tree, grad_treedef))

   leaf_states = map(tree_unflatten, subtrees, packed_state)
   step = partial(update, i)
   updated_states = map(step, grad_flat, leaf_states)
   updated_flat, updated_subtrees = unzip2(map(tree_flatten, updated_states))

   # Each per-leaf state must keep the structure it had on input.
   for expected, actual in zip(subtrees, updated_subtrees):
     if actual != expected:
       msg = ("optimizer update function produced an output structure that "
              "did not match its input structure: input {} and output {}.")
       raise TypeError(msg.format(expected, actual))

   return OptimizerState(pack(map(pack, updated_flat)), tree, subtrees)
Ejemplo n.º 28
0
def pack_optimizer_state(marked_pytree):
    """Converts a marked pytree to an OptimizerState.

    The inverse of unpack_optimizer_state. Converts a marked pytree with the
    leaves of the outer pytree represented as JoinPoints back into an
    OptimizerState. This function is intended to be useful when deserializing
    optimizer states.

    Args:
      marked_pytree: A pytree containing JoinPoint leaves that hold more
        pytrees.

    Returns:
      An equivalent OptimizerState to the input argument.
    """
    join_points, tree_def = tree_flatten(marked_pytree)
    assert all(isinstance(point, JoinPoint) for point in join_points)
    # Flatten the subtree held by each JoinPoint, then separate the flat
    # leaves from the subtree definitions.
    flattened = [tree_flatten(point.subtree) for point in join_points]
    states_flat, subtree_defs = unzip2(flattened)
    return OptimizerState(states_flat, tree_def, subtree_defs)
Ejemplo n.º 29
0
def _revise_cond_jaxpr(new_pval, old_pval, jaxpr, consts):
    """Adapt a cond branch jaxpr after its output partial value was joined.

    Args:
      new_pval: `(pv, const)` pair after the join; only the `pv` part is used.
      old_pval: `(pv, const)` pair the jaxpr was originally built for.
      jaxpr: the branch jaxpr to revise.
      consts: constants accompanying `jaxpr`.

    Returns:
      A `(jaxpr, consts)` pair: the inputs unchanged when the `pv` parts are
      equal, otherwise a copy of the jaxpr with updated constvars and matching
      constants.
    """
    new_pv, new_const = new_pval
    old_pv, old_const = old_pval
    if new_pv == old_pv:
        # we didn't move up the lattice by joining with the other side
        return jaxpr, consts
    elif old_pv is None:
        # we moved up the lattice from totally-known, so make a new jaxpr that
        # returns a single constant JaxTuple with elements that are constants
        # drawn from consts where new_pv is unknown
        assert not jaxpr.eqns and not consts
        outvar = pe.Var(0, "_cond")
        new_jaxpr = jaxpr.copy()
        new_jaxpr.constvars = [outvar]
        new_jaxpr.outvar = outvar
        # Entries whose new pv is None are replaced by `core.unit`; the others
        # take the old constant.
        new_consts = (core.pack([
            core.unit if pv is None else old_c
            for pv, old_c in zip(new_pv, old_const)
        ]), )
        return new_jaxpr, new_consts
    else:
        # we moved up the lattice, but not from totally-constant, so adapt the
        # jaxpr to return some new constants in places that are now unknown but
        # weren't before
        eqn = jaxpr.eqns[-1]
        # The final equation must be the pack that produces the jaxpr's output.
        assert eqn.primitive == core.pack_p
        assert len(eqn.outvars) == 1 and eqn.outvars[0] == jaxpr.outvar
        newvar = pe.gensym("_cond")
        # One fresh constvar per position that had a None old pv and now has a
        # non-None new pv, paired with the old constant at that position.
        new_constvars, new_constvals = unzip2([
            (newvar(), c) for new, old, c in zip(new_pv, old_pv, old_const)
            if old is None and new is not None
        ])
        new_consts = consts + tuple(new_constvals)
        new_jaxpr = jaxpr.copy()
        new_jaxpr.constvars = tuple(jaxpr.constvars) + tuple(new_constvars)
        # Consume the fresh constvars in the same positional order in which
        # they were created above.
        newvars = iter(new_constvars)
        new_invars = [
            next(newvars) if old is None and new is not None else
            (core.unitvar if new is None and old is None else v)
            for new, old, v in zip(new_pv, old_pv, eqn.invars)
        ]
        # Replace the final pack equation with one over the revised inputs.
        new_jaxpr.eqns = (list(jaxpr.eqns[:-1]) +
                          [_pack_eqn(new_invars, jaxpr.outvar)])
        return new_jaxpr, new_consts
Ejemplo n.º 30
0
    def default_process_primitive(self, primitive, tracers, params):
        """Partially evaluate a primitive and record variable recipes.

        If every input's partial value has a `None` first component, the
        primitive is bound directly on the constants. Otherwise the inputs are
        instantiated, output tracers are created from the primitive's abstract
        evaluation, and an equation recipe is attached to them. When the
        primitive is `harvest.sow_p` with the configured tag and all inputs are
        keys, a `VariableRecipe` is attached as well.

        Args:
          primitive: the primitive being processed.
          tracers: input tracers.
          params: parameters of the primitive.

        Returns:
          The result of `primitive.bind` when everything is known; otherwise
          a list of output tracers if `primitive.multiple_results`, else the
          single output tracer.
        """
        pvs, consts = jax_util.unzip2(t.pval for t in tracers)
        # Everything is known: evaluate eagerly.
        if all(pv is None for pv in pvs):
            return primitive.bind(*consts, **params)
        settings = trace_util.get_dynamic_context(self).settings
        tracers = safe_map(self.instantiate_const, tracers)
        # Every instantiated input is expected to be an UnzipTracer.
        if any(not isinstance(t, UnzipTracer) for t in tracers):
            assert False
        # The outputs are marked as keys only if every input is a key.
        key = all(t.is_key() for t in tracers)
        avals = [t.aval for t in tracers]
        ans = primitive.abstract_eval(*avals, **params)
        # Normalize single results to a list so both cases share one code path.
        if not primitive.multiple_results:
            ans = [ans]
        out_tracers = [
            UnzipTracer(self, pe.PartialVal((aval, jax_core.unit)), None, key)
            for aval in ans
        ]
        # Passing in UnzipTracer, which pytype does not recognize as JaxprTracer
        eqn = pe.new_eqn_recipe(tracers, out_tracers, primitive, params,
                                source_info_util.current())  # pytype: disable=wrong-arg-types
        for t in out_tracers:
            t.recipe = eqn

        is_variable = (key and primitive is harvest.sow_p
                       and params['tag'] == settings.tag)
        # This block is where UnzipTrace mainly differs from pe.JaxprTrace. Where
        # JaxprTrace will just return out_tracers, UnzipTrace will record an
        # additional VariableRecipe into the tracers, which will be used after
        # the trace is complete to construct init/apply Jaxprs.
        if is_variable:
            name, var_in_tracers, var_out_tracers = unzip_registry[primitive](
                tracers, out_tracers, **params)
            variable_recipe = VariableRecipe(name, var_in_tracers,
                                             var_out_tracers)
            for t in out_tracers:
                t.variable_recipe = variable_recipe

        if primitive.multiple_results:
            return out_tracers
        return out_tracers[0]