Example #1
0
 def start_subtrace(self):
   """Push a new JaxprTrace main onto the trace stack and return its trace.

   Mirrors the ``__enter__`` half of ``core.new_main``. It allocates the next
   level, builds a ``core.MainTrace`` of type ``pe.JaxprTrace`` and pushes it
   onto the thread-local trace stack. Nothing is popped here.

   When omnistaging is disabled, the stack calls receive an extra ``False``
   argument. That is the only difference between the two modes.

   Side effect: increments ``self._count_subtraces``.

   Returns:
     A ``pe.JaxprTrace`` bound to the new main at the current sublevel.
   """
   # TODO: This follows the __enter__ part of core.new_main.
   stack = core.thread_local_state.trace_state.trace_stack
   # Non-omnistaging stack API takes a trailing ``False`` flag; omnistaging
   # calls take no extra argument.
   flag_args = () if config.omnistaging_enabled else (False,)
   new_main = core.MainTrace(stack.next_level(*flag_args), pe.JaxprTrace)
   stack.push(new_main, *flag_args)
   self._count_subtraces += 1
   return pe.JaxprTrace(new_main, core.cur_sublevel())
Example #2
0
def callback_subtrace(master, *in_vals, **params):
    """Generator subtrace that runs a function under a ``CallbackTrace``.

    First yield: ``CallbackTracer`` wrappers for ``in_vals`` together with
    ``params``; the driver sends back the function outputs.
    Second yield: the ``.val`` of each output after ``full_raise``.
    """
    cb_trace = CallbackTrace(master, core.cur_sublevel())
    results = yield [CallbackTracer(cb_trace, v) for v in in_vals], params
    raised = map(cb_trace.full_raise, results)
    yield [r.val for r in raised]
Example #3
0
File: jet.py Project: nhanwei/jax
def jet_subtrace(main, primals, series):
    """Generator subtrace that evaluates a function under a ``JetTrace``.

    Each primal is paired with its ``series`` entry as a ``JetTracer``. After
    the driver sends back the outputs, they are raised to this trace and split
    into ``(out_primals, out_terms)``, which is yielded.
    """
    jet_trace = JetTrace(main, core.cur_sublevel())
    make_tracer = partial(JetTracer, jet_trace)
    results = yield map(make_tracer, primals, series), {}
    raised = map(jet_trace.full_raise, results)
    primals_out, terms_out = unzip2([(r.primal, r.terms) for r in raised])
    yield primals_out, terms_out
Example #4
0
 def start_subtrace():
     """Push a new JaxprTrace master onto the trace stack and return its trace.

     Mirrors the ``__enter__`` half of ``core.new_master``. The next level is
     obtained with ``next_level(False)`` and the master is pushed with
     ``push(master, False)``. Nothing is popped here.

     Returns:
       A ``pe.JaxprTrace`` bound to the new master at the current sublevel.
     """
     # TODO: This follows the __enter__ part of core.new_master; share code.
     trace_stack = core.trace_state.trace_stack
     fresh_master = core.MasterTrace(trace_stack.next_level(False),
                                     pe.JaxprTrace)
     trace_stack.push(fresh_master, False)
     return pe.JaxprTrace(fresh_master, core.cur_sublevel())
Example #5
0
def _axis_index_bind(*, axis_name):
    """Compute the index along ``axis_name``.

    If the axis frame carries a main trace, a trace of that main's
    ``trace_type`` is built at the current sublevel and asked for
    ``process_axis_index(frame)``. Otherwise the ``axis_index_p`` primitive is
    bound directly.
    """
    frame = core.axis_frame(axis_name)
    owner = frame.main_trace
    if owner is None:
        return core.Primitive.bind(axis_index_p, axis_name=axis_name)
    owner_trace = owner.trace_type(owner, core.cur_sublevel())
    return owner_trace.process_axis_index(frame)
Example #6
0
 def current(cls):
   """Snapshot the current JAX trace position as a ``JaxTraceLevel``.

   The opaque payload is a 3-tuple:
   ``(trace_type of stack[0], level of stack[-1], jax_core.cur_sublevel())``.
   """
   # TODO(tomhennigan): Remove once a version of JAX is released incl PR#9423.
   entries = jax_core.thread_local_state.trace_state.trace_stack.stack
   # NOTE(review): the original named this value ``top_type`` but reads index
   # 0, presumably the first-pushed entry. Confirm this is intended.
   first, last = entries[0], entries[-1]
   payload = (first.trace_type, last.level, jax_core.cur_sublevel())
   return JaxTraceLevel(opaque=payload)
Example #7
0
def _interpret_subtrace(master: core.MasterTrace, *in_vals: TfValOrUnit):
  """Generator subtrace that evaluates a function under a TensorFlowTrace.

  First yield: a tuple of ``TensorFlowTracer`` wrappers for ``in_vals``.
  Second yield: a tuple with the ``.val`` of each output after ``full_raise``.
  """
  tf_trace = TensorFlowTrace(master, core.cur_sublevel())
  wrapped_in = tuple(TensorFlowTracer(tf_trace, v) for v in in_vals)
  results = yield wrapped_in, {}  # type: Sequence[TfValOrUnit]
  raised: Iterable[TensorFlowTracer] = map(tf_trace.full_raise, results)  # type: ignore
  yield tuple(r.val for r in raised)
Example #8
0
def _interpret_subtrace(master: core.MasterTrace, *in_vals: TfVal):
    """Generator subtrace that evaluates a function under a TensorFlowTrace.

    First yield: a list of ``TensorFlowTracer`` wrappers for ``in_vals``.
    Second yield: a list with the ``.val`` of each output after ``full_raise``.
    """
    tf_trace = TensorFlowTrace(master, core.cur_sublevel())
    results = yield [TensorFlowTracer(tf_trace, v) for v in in_vals], {}
    raised = map(tf_trace.full_raise, results)
    yield [r.val for r in raised]
Example #9
0
def apply_transform(net_params, inputs):
    """Run a function on ``inputs`` under a fresh ApplyTrace and yield its output.

    Generator-style transformation. The first yield hands one
    ``ApplyTracer(trace, net_params, x)`` per input ``x`` (plus empty params)
    to the driver and receives a single traced result. The second yield
    produces that result's ``.val``.
    """
    with jc.new_master(ApplyTrace) as master:
        trace = ApplyTrace(master, jc.cur_sublevel())
        ans = yield map(partial(ApplyTracer, trace, net_params), inputs), {}
        out_tracer = trace.full_raise(ans)
        out_val = out_tracer.val
        # Drop local references to the master and the output tracer before
        # new_master exits. Presumably this keeps them from being seen as
        # leaked references; confirm against new_master's exit logic.
        del master, out_tracer
    yield out_val
Example #10
0
def doubling_subtrace(main, heads, tails):
  """Generator subtrace that evaluates a function under a DoublingTrace.

  Each ``(head, tail)`` pair becomes a ``DoublingTracer``. When ``tail`` is
  None, the head is passed through unwrapped. After the driver sends back the
  outputs, they are raised to this trace and ``unzip2`` of their
  ``(head, tail)`` pairs is yielded.
  """
  dbl_trace = DoublingTrace(main, core.cur_sublevel())
  wrapped_in = []
  for head, tail in zip(heads, tails):
    if tail is None:
      wrapped_in.append(head)
    else:
      wrapped_in.append(DoublingTracer(dbl_trace, head, tail))
  results = yield wrapped_in, {}
  raised = map(dbl_trace.full_raise, results)
  yield unzip2([(r.head, r.tail) for r in raised])
Example #11
0
File: jet.py Project: romanodev/jax
def jet_transform(primals, series):
  """Generator-style transformation that evaluates a function under a JetTrace.

  Each primal is paired with its entry in ``series`` to build a JetTracer.
  The first yield hands those tracers (and empty params) to the driver and
  receives the outputs. The final yield produces ``(out_primals, out_terms)``.
  """
  with core.new_master(JetTrace) as master:
    trace = JetTrace(master, core.cur_sublevel())
    in_tracers = map(partial(JetTracer, trace), primals, series)
    ans = yield in_tracers, {}
    out_tracers = map(trace.full_raise, ans)
    # Split each output tracer into its primal value and its series terms.
    out_primals, out_terms = unzip2((t.primal, t.terms) for t in out_tracers)
  # Yielded only after the new_master context has exited.
  yield out_primals, out_terms
Example #12
0
 def start_subtrace(self):
     """Push a new JaxprTrace main onto the trace stack and return its trace.

     Mirrors the ``__enter__`` half of ``core.new_main``. The current name
     stack is captured and attached both to the new ``core.MainTrace`` and to
     the returned ``pe.JaxprTrace``. Nothing is popped here.

     Side effect: increments ``self._count_subtraces``.
     """
     # TODO: This follows the __enter__ part of core.new_main.
     stack = core.thread_local_state.trace_state.trace_stack
     lvl = stack.next_level()
     names = source_info_util.current_name_stack()
     main_trace = core.MainTrace(lvl, pe.JaxprTrace, name_stack=names)
     stack.push(main_trace)
     self._count_subtraces += 1
     sublevel = core.cur_sublevel()
     return pe.JaxprTrace(main_trace, sublevel, name_stack=names)
Example #13
0
 def start_subtrace(self):
     """Push a new JaxprTrace master onto the trace stack and return its trace.

     Mirrors the ``__enter__`` half of ``core.new_master``, using the
     ``False``-flagged forms of ``next_level`` and ``push``. Nothing is popped
     here.

     Side effect: increments ``self._count_subtraces``.
     """
     # TODO: This follows the __enter__ part of core.new_master.
     stack = core.thread_local_state.trace_state.trace_stack
     fresh_master = core.MasterTrace(stack.next_level(False), pe.JaxprTrace)
     stack.push(fresh_master, False)
     self._count_subtraces += 1
     return pe.JaxprTrace(fresh_master, core.cur_sublevel())
Example #14
0
 def _apply_subtrace(master, submodule_params, *vals):
     """Generator subtrace that evaluates a submodule under an ApplyTrace.

     ``submodule_params`` is unwrapped via its ``.val`` attribute. It is
     presumably a hashable wrapper around the params; confirm with callers.
     Each input is wrapped as ``ApplyTrace.Tracer(trace, params_iter, v)``,
     with one shared ``SubmoduleParamsIterator``. The second yield is the
     list of output ``.val`` values after ``full_raise``.
     """
     submodule_params = submodule_params.val
     apply_trace = ApplyTrace(master, jc.cur_sublevel())
     params_iter = ApplyTrace.SubmoduleParamsIterator(submodule_params)
     make_tracer = partial(ApplyTrace.Tracer, apply_trace, params_iter)
     results = yield map(make_tracer, vals), {}
     raised = map(apply_trace.full_raise, results)
     yield [r.val for r in raised]
Example #15
0
def trace_to_subjaxpr_dynamic(fun: lu.WrappedFun, main: core.MainTrace,
                              in_avals: Sequence[core.AbstractValue]):
  """Trace ``fun`` to a jaxpr using a DJaxprTrace with dimension tracers.

  Inside ``pe.extend_jaxpr_stack(main, frame)``:
  - a DJaxprTrace is created;
  - dimension tracers are placed into the input shapes;
  - one argument tracer is created per aval returned by that step;
  - ``fun`` is called on those tracers and its outputs are raised.
  After the context exits, output dimension tracers are extracted and
  ``frame.to_jaxpr`` assembles the result.
  """
  stack_frame = DJaxprStackFrame()
  with pe.extend_jaxpr_stack(main, stack_frame):
    djax_trace = DJaxprTrace(main, core.cur_sublevel())
    dims_in, avals_with_dims = _place_in_dim_tracers_in_shapes(
        djax_trace, in_avals)
    arg_tracers = map(djax_trace.new_arg, avals_with_dims)
    raw_outs = fun.call_wrapped(*arg_tracers)
    result_tracers = map(djax_trace.full_raise, raw_outs)
  dims_out = _extract_out_dim_tracers_from_shapes(main, dims_in,
                                                  result_tracers)
  return stack_frame.to_jaxpr(dims_in, arg_tracers, dims_out, result_tracers)
Example #16
0
File: ad.py Project: jbampton/jax
def jvp_subtrace(main, primals, tangents):
  """Generator subtrace that evaluates a function under a JVPTrace.

  Asserts that every incoming Tracer (in primals or tangents) has a trace
  level strictly below this trace's level. Pairs whose tangent is a ``Zero``
  are passed through as the bare primal; all others become JVPTracers. After
  the driver sends back the outputs, ``unzip2`` of their
  ``(primal, tangent)`` pairs is yielded.
  """
  jvp_trace = JVPTrace(main, core.cur_sublevel())
  for val in (*primals, *tangents):
    if isinstance(val, Tracer):
      assert val._trace.level < jvp_trace.level
  wrapped_in = [p if type(t) is Zero else JVPTracer(jvp_trace, p, t)
                for p, t in zip(primals, tangents)]
  results = yield wrapped_in, {}
  raised = map(jvp_trace.full_raise, results)
  yield unzip2([(r.primal, r.tangent) for r in raised])
Example #17
0
def doubling_transform(*args):
    """Generator-style transformation that evaluates a function under a DoublingTrace.

    Each element of ``args`` is unpacked as a ``(head, tail)`` pair and wrapped
    in a DoublingTracer. After the driver sends back the outputs, the final
    yield is:
      - a list of ``(head, tail)`` pairs if the outputs are a ``Sequence``;
      - otherwise a single ``(head, tail)`` pair.
    """
    with core.new_main(DoublingTrace) as main:
        trace = DoublingTrace(main, core.cur_sublevel())
        in_tracers = [DoublingTracer(trace, head, tail) for head, tail in args]
        outputs = yield in_tracers, {}
        # Accept either several outputs or a single (non-Sequence) output.
        if isinstance(outputs, Sequence):
            out_tracers = map(trace.full_raise, outputs)
            result = [(x.head, x.tail) for x in out_tracers]
        else:
            out_tracer = trace.full_raise(outputs)
            result = (out_tracer.head, out_tracer.tail)
    # Yielded only after the new_main context has exited.
    yield result
Example #18
0
def plant_function(main: jax_core.MainTrace, settings: HarvestSettings,
                   in_tree: Any, args: Iterable[Any]):
    """A function transformation that injects values in place of sows.

    ``args`` is unflattened with ``in_tree`` into ``(plants, args)``. The real
    args are lifted with ``trace.pure`` and the function runs inside a dynamic
    ``PlantContext(settings, plants)``. The final yield is the list of output
    ``.val`` values.
    """
    trace = HarvestTrace(main, jax_core.cur_sublevel())
    plants, args = tree_util.tree_unflatten(in_tree, args)
    args = jax_util.safe_map(trace.pure, args)
    context = PlantContext(settings, plants)
    with trace_util.new_dynamic_context(main, context):
        ans = yield args, {}
        out_tracers = jax_util.safe_map(trace.full_raise, ans)
        # Drop the local reference to ``main`` before the dynamic context
        # exits. Presumably this is for leak checking; confirm.
        del main
    yield [t.val for t in out_tracers]
Example #19
0
File: ad.py Project: jbampton/jax
def jvp_subtrace_aux(main, primals, tangents):
  """JVP generator subtrace for functions that also return auxiliary data.

  Asserts every incoming Tracer has a trace level strictly below this trace's
  level, then wraps every (primal, tangent) pair as a JVPTracer. The driver
  sends back ``(ans, aux)``. The final yield is
  ``((out_primals, out_tangents), aux_primals)``. In ``aux_primals``, any
  JVPTracer belonging to this trace's level is replaced by
  ``core.full_lower`` of its primal; other entries pass through unchanged.
  """
  jvp_trace = JVPTrace(main, core.cur_sublevel())
  for val in (*primals, *tangents):
    if isinstance(val, Tracer):
      assert val._trace.level < jvp_trace.level
  make_tracer = partial(JVPTracer, jvp_trace)
  ans, aux = yield map(make_tracer, primals, tangents), {}
  raised = map(jvp_trace.full_raise, ans)
  out_primals, out_tangents = unzip2((r.primal, r.tangent) for r in raised)
  aux_primals = []
  for item in aux:
    is_ours = (isinstance(item, JVPTracer)
               and item._trace.level == jvp_trace.level)
    aux_primals.append(core.full_lower(item.primal) if is_ours else item)
  yield (out_primals, out_tangents), aux_primals
Example #20
0
def _top_trace(filter_type=Trace):
    """Build a trace for the last stack entry whose trace type matches.

    Needed when a parametrized function has no arguments, so it cannot
    retrieve the trace from its input tracers.

    Scans the thread-local trace stack. Among the entries whose
    ``trace_type`` is a subclass of ``filter_type``, it takes the last one and
    returns ``entry.trace_type(entry, cur_sublevel())``. Returns None if no
    entry matches.
    """
    last_match = None
    for entry in thread_local_state.trace_state.trace_stack.stack:
        if issubclass(entry.trace_type, filter_type):
            last_match = entry
    if last_match is None:
        return None
    return last_match.trace_type(last_match, cur_sublevel())
Example #21
0
def mask_subtrace(main, shapes, padded_env, *in_vals):
  """Generator subtrace that evaluates a function under a MaskTrace.

  ``padded_env`` is a ``(keys, values)`` pair. The first ``len(keys)`` entries
  of ``in_vals`` are the logical values for those keys; the remaining entries
  are the real inputs, each wrapped as a MaskTracer with the matching entry of
  ``shapes`` and then ``full_lower``-ed. The function runs with both shape
  environments installed via ``extend_shape_envs``. The final yield is
  ``(out_vals, out_shapes)``.
  """
  env_keys, _ = padded_env
  n_env = len(env_keys)
  logical_env = dict(zip(env_keys, in_vals[:n_env]))
  padded_env = dict(zip(*padded_env))
  mask_trace = MaskTrace(main, core.cur_sublevel())
  wrapped_in = [MaskTracer(mask_trace, val, shape).full_lower()
                for val, shape in zip(in_vals[n_env:], shapes)]
  with extend_shape_envs(logical_env, padded_env):
    results = yield wrapped_in, {}
  raised = map(mask_trace.full_raise, results)
  vals_out, shapes_out = unzip2(
      (r.val, r.polymorphic_shape) for r in raised)
  yield vals_out, shapes_out
Example #22
0
def _init_transform(key, *inputs):
    """Transforms a flattened `parametrized` function
    into its corresponding `init_parameters` function.

    If an enclosing InitTrace is on the stack (found via ``_top_trace``), its
    ``global_parameters_dict`` and ``random_state`` are reused. Otherwise an
    empty dict and ``RandomState(key)`` are used. The final yield is
    ``(outs, parameters_dict)``, with ``outs`` lowered by ``trace.lower_all``.
    """
    init_trace = _top_trace(filter_type=InitTrace)
    with new_main(InitTrace) as master:
        # Share state with an enclosing InitTrace when one exists.
        global_parameters_dict = init_trace.state.global_parameters_dict if init_trace else {}
        random_state = init_trace.state.random_state if init_trace else RandomState(key)
        master.state = InitTraceState(random_state, global_parameters_dict)
        trace = InitTrace(master, cur_sublevel())
        outs = yield map(trace.full_raise, inputs), {}
        outs = trace.lower_all(outs)
        # Read the collected parameters before dropping the master reference.
        parameters_dict = master.state.parameters_dict
        # Drop the local reference before new_main exits. Presumably this is
        # for leak checking; confirm.
        del master
    yield outs, parameters_dict
Example #23
0
    def wrapped(*args, **kwargs):
        """Callable returned by unzip.

        Traces ``f`` (from the enclosing scope) on the avals of ``args`` under
        a fresh UnzipTrace and splits it into an init jaxpr and an apply
        jaxpr.

        Args:
          *args: example arguments. Only their avals (via
            ``trace_util.get_shaped_aval``) are used for tracing.
          **kwargs: keyword arguments bound into ``f`` via ``lu.wrap_init``.

        Returns:
          A pair ``(init, apply)``:
            - ``init(*args)`` evaluates the init jaxpr and returns a dict
              mapping each variable name to its value.
            - ``apply(params, *args)`` reads ``params[name]`` for every
              variable name, evaluates the apply jaxpr on those variables
              followed by ``args``, and unflattens the result with the traced
              output tree.

        Raises:
          ValueError: if tracing reports that the variables do not cut the
            dependence graph.
        """
        with jax_core.new_master(UnzipTrace) as master:
            # Preparing args to be traced
            trace = UnzipTrace(master, jax_core.cur_sublevel())
            fun = lu.wrap_init(f, kwargs)
            avals = tree_util.tree_map(trace_util.get_shaped_aval, args)
            # ``key_args`` comes from the enclosing scope (not visible here).
            flat_avals, flat_keys, in_tree = (flatten_args_into_keys(
                avals, key_args))
            # All inputs are treated as unknown for partial evaluation.
            flat_pvals = [pe.PartialVal.unknown(aval) for aval in flat_avals]
            flat_fun, out_tree = api_util.flatten_fun_nokwargs(fun, in_tree)

            # Trace to jaxpr
            settings = UnzipSettings(tag, False)
            fun = unzip_to_init_apply_subjaxprs(flat_fun, trace, settings)  # pylint: disable=no-value-for-parameter
            success, results = fun.call_wrapped(flat_keys, flat_pvals)
            if not success:
                raise ValueError('Variables do not cut dependence graph.')
            init_out, apply_out, _, metadata = results
            init_jaxpr, init_consts, init_env = init_out
            assert not init_env

            apply_jaxpr, apply_consts, apply_env = apply_out
            assert not apply_env

            names, variable_tree, _ = metadata
            # ``out_tree`` is a thunk from flatten_fun_nokwargs; it is
            # evaluated after tracing has run.
            out_tree = out_tree()

            # Final functions
            def init(*args):
                flat_args, _ = tree_util.tree_flatten(args)
                flat_params = jax_core.eval_jaxpr(init_jaxpr, init_consts,
                                                  *flat_args)
                flat_variables = tree_util.tree_unflatten(
                    variable_tree, flat_params)
                return {
                    name: var
                    for name, var in safe_zip(names, flat_variables)
                }

            def apply(params, *args):
                flat_variables, _ = tree_util.tree_flatten(
                    [params[name] for name in names])
                flat_args, _ = tree_util.tree_flatten(args)
                # Variables come first, then the call arguments.
                out = jax_core.eval_jaxpr(apply_jaxpr, apply_consts,
                                          *(flat_variables + flat_args))
                return tree_util.tree_unflatten(out_tree, out)

            # Drop the local reference before new_master exits. Presumably
            # this is for leak checking; confirm.
            del master
        return init, apply
Example #24
0
def jvp_subtrace(main, primals, tangents):
  """Generator subtrace that evaluates a function under a JVPTrace.

  Any incoming Tracer (in primals or tangents) whose trace level is not
  strictly below this trace's level causes ``core.escaped_tracer_error(...)``
  to be raised. Pairs whose tangent is a ``Zero`` pass through as the bare
  primal; others become JVPTracers. After the driver sends back the outputs,
  ``unzip2`` of their ``(primal, tangent)`` pairs is yielded.
  """
  trace = JVPTrace(main, core.cur_sublevel())
  for x in (*primals, *tangents):
    if not isinstance(x, Tracer):
      continue
    if x._trace.level >= trace.level:
      raise core.escaped_tracer_error(
          x, f"Tracer from a higher level: {x} in trace {trace}")
    # Guaranteed by the check above; kept as an invariant assertion.
    assert x._trace.level < trace.level
  wrapped_in = [p if type(t) is Zero else JVPTracer(trace, p, t)
                for p, t in zip(primals, tangents)]
  results = yield wrapped_in, {}
  raised = map(trace.full_raise, results)
  yield unzip2([(r.primal, r.tangent) for r in raised])
Example #25
0
def _axis_index_bind(*, axis_name):
    """Compute a combined index over one or more named axes.

    A single (non-tuple/list) name is treated as a 1-tuple. Names are visited
    last-to-first, and each per-axis index is scaled by the product of
    ``psum(1, name)`` over the names after it, so the last name varies
    fastest.

    For each axis, if its frame carries a main trace, the index comes from
    that trace's ``process_axis_index(frame)``. Otherwise ``axis_index_p`` is
    bound directly.
    """
    names = axis_name if isinstance(axis_name, (tuple, list)) else (axis_name,)
    total, stride = 0, 1
    for name in names[::-1]:
        frame = core.axis_frame(name)
        owner = frame.main_trace
        if owner is None:
            idx = core.Primitive.bind(axis_index_p, axis_name=name)
        else:
            owner_trace = owner.trace_type(owner, core.cur_sublevel())
            idx = owner_trace.process_axis_index(frame)
        total += idx * stride
        stride *= psum(1, name)
    return total
Example #26
0
def reap_function(main: jax_core.MainTrace, settings: HarvestSettings,
                  return_metadata: bool, args: Iterable[Any]):
    """A function transformation that returns reap values.

    Lifts ``args`` with ``trace.pure`` and runs the function inside a dynamic
    ``ReapContext`` whose reaps start as an empty dict. The final yield is
    ``(out_values, reap_values)``, or
    ``(out_values, reap_values, reap_metadata)`` when ``return_metadata`` is
    true.
    """
    trace = HarvestTrace(main, jax_core.cur_sublevel())
    in_tracers = jax_util.safe_map(trace.pure, args)
    context = ReapContext(settings, {})
    with trace_util.new_dynamic_context(main, context):
        ans = yield in_tracers, {}
        out_tracers = jax_util.safe_map(trace.full_raise, ans)
        # Each reap entry exposes ``.value`` (raised to this trace) and
        # ``.metadata``; both are collected while the context is active.
        reap_tracers = tree_util.tree_map(lambda x: trace.full_raise(x.value),
                                          context.reaps)
        reap_metadata = tree_util.tree_map(lambda x: x.metadata, context.reaps)
        # Drop the local reference before the dynamic context exits.
        # Presumably this is for leak checking; confirm.
        del main
    out_values, reap_values = tree_util.tree_map(lambda x: x.val,
                                                 (out_tracers, reap_tracers))
    if return_metadata:
        out = (out_values, reap_values, reap_metadata)
    else:
        out = (out_values, reap_values)
    yield out
Example #27
0
def harvest_function(master: jax_core.MainTrace, settings: HarvestSettings,
                     in_tree, args: Iterable[Any]):
    """A JAX linear_util transformation that runs a HarvestTrace.

    ``args`` is unflattened with ``in_tree`` into ``(plants, args)``. The args
    are lifted with ``trace.pure`` and the function runs inside a dynamic
    ``HarvestContext(settings, {}, plants)``. The final yield is
    ``(output vals, reaped vals)``. Any reaped ``HarvestList`` is converted
    with ``as_array()`` first.
    """
    trace = HarvestTrace(master, jax_core.cur_sublevel())
    plants, args = tree_util.tree_unflatten(in_tree, args)
    in_tracers = safe_map(trace.pure, args)
    context = HarvestContext(settings, {}, plants)
    with trace_util.new_dynamic_context(master, context):
        ans = yield in_tracers, {}
        out_tracers = safe_map(trace.full_raise, ans)
        reaps = tree_util.tree_map(trace.full_raise, context.reaps)
        # Drop the local reference before the dynamic context exits.
        # Presumably this is for leak checking; confirm.
        del master
    reaped_tracers = {}
    for key, reaped_tracer in reaps.items():
        # HarvestList entries are collapsed into a single array.
        if isinstance(reaped_tracer, HarvestList):
            reaped_tracers[key] = reaped_tracer.as_array()
        else:
            reaped_tracers[key] = reaped_tracer
    yield ([t.val for t in out_tracers],
           tree_util.tree_map(lambda t: t.val, reaped_tracers))
Example #28
0
File: jet.py Project: nhanwei/jax
 def todo(x):
     """Rebuild JetTracers from flat ``x`` using ``treedef`` and ``main`` from the enclosing scope."""
     primals, series = tree_unflatten(treedef, x)
     jet_trace = JetTrace(main, core.cur_sublevel())
     make_tracer = partial(JetTracer, jet_trace)
     return map(make_tracer, primals, series)
Example #29
0
 def todo(x):
     """Wrap each element of ``x`` in a HarvestTracer on a new HarvestTrace for ``master``."""
     harvest_trace = HarvestTrace(master, jax_core.cur_sublevel())
     make_tracer = functools.partial(HarvestTracer, harvest_trace)
     return jax_util.safe_map(make_tracer, x)
Example #30
0
File: ad.py Project: jbampton/jax
 def todo(x):
   """Rebuild JVPTracers from flat ``x`` using ``treedef`` and ``main`` from the enclosing scope."""
   primals, tangents = tree_unflatten(treedef, x)
   jvp_trace = JVPTrace(main, core.cur_sublevel())
   make_tracer = partial(JVPTracer, jvp_trace)
   return map(make_tracer, primals, tangents)