def _update_annotation(
    f: lu.WrappedFun,
    orig_type: Optional[Tuple[Tuple[core.AbstractValue, bool], ...]],
    nonzeros: List[bool]
  ) -> lu.WrappedFun:
  """Extend f's input-type annotation with tangent types for nonzero tangents.

  Args:
    f: the wrapped function whose annotation is being extended.
    orig_type: the primal input-type annotation, a tuple of
      (aval, explicit) pairs, or None if f carries no annotation.
    nonzeros: per-*explicit*-argument flags marking which primals have
      symbolically nonzero tangents.

  Returns:
    f annotated with the primal types followed by the tangent types, or f
    unchanged when there is no annotation to extend.
  """
  if orig_type is None:
    return f
  # `nonzeros` lines up with the explicit arguments only: implicit (inferred)
  # arguments never have tangents. Zipping `nonzeros` against the full
  # `orig_type` — as this function previously did — misaligns the two
  # sequences whenever implicit arguments are present, so first select the
  # explicit avals and pair against those (consistent with the other
  # `_update_annotation` variant in this file).
  explicit_avals = [aval for aval, explicit in orig_type if explicit]
  tan_types = [(aval.at_least_vspace(), True)
               for nz, aval in zip(nonzeros, explicit_avals) if nz]
  return lu.annotate(f, (*orig_type, *tan_types))
def _update_annotation(
    f: lu.WrappedFun,
    orig_type: Optional[Tuple[Tuple[core.AbstractValue, bool], ...]],
    axis_size: int, axis_name: core.AxisName,
    in_dims: Sequence[Optional[int]]) -> lu.WrappedFun:
  """Rewrite f's input-type annotation for the unmapped (unbatched) view.

  Each input aval is replaced by its unmapped counterpart along the named
  batch axis (per `in_dims`); the explicit/implicit flag of every entry is
  carried over untouched. Returns f unchanged when it has no annotation.
  """
  if orig_type is None:
    return f
  unbatched = tuple(
      (core.unmapped_aval(axis_size, axis_name, bdim, aval), keep)
      for bdim, (aval, keep) in zip(in_dims, orig_type))
  return lu.annotate(f, unbatched)
def _update_annotation(
    f: lu.WrappedFun,
    orig_type: Optional[Tuple[Tuple[core.AbstractValue, bool], ...]],
    nonzeros: List[bool]
  ) -> lu.WrappedFun:
  """Append tangent types for nonzero tangents to f's input-type annotation.

  Only explicit arguments can carry tangents — implicit (inferred) arguments
  never do — so the tangent portion of the annotation is derived from the
  explicit entries of `orig_type` alone. Returns f unchanged when it has no
  annotation.
  """
  if orig_type is None:
    return f
  # Pair each explicit aval with its nonzero-tangent flag; keep only the
  # nonzero ones, promoted to their vector-space representatives.
  explicit_avals = (aval for aval, explicit in orig_type if explicit)
  tangent_part = [(aval.at_least_vspace(), True)
                  for aval, nz in zip(explicit_avals, nonzeros) if nz]
  return lu.annotate(f, (*orig_type, *tangent_part))
def call_transpose(primitive, params, call_jaxpr, args, ct, _, reduce_axes):
  """Transpose rule for call primitives: run backward_pass under the call."""
  # Flatten (consts, primal args, cotangents); there are no consts here.
  flat_args, tree_def = tree_flatten(((), args, ct))
  transposed = lu.hashable_partial(
      lu.wrap_init(backward_pass), call_jaxpr, reduce_axes, False)
  transposed, out_tree = flatten_fun_nokwargs(transposed, tree_def)
  if not config.jax_experimental_name_stack:
    # Tag the call's name so the transpose is identifiable in traces/profiles.
    params = dict(params, name=wrap_name(params['name'], 'transpose'))
  update_params = call_transpose_param_updaters.get(primitive)
  if update_params:
    nonzero_cts = [type(x) is not Zero for x in ct]
    params = update_params(params, map(is_undefined_primal, args), nonzero_cts)
  if config.jax_dynamic_shapes:
    # Under dynamic shapes the bound function must carry an input-type
    # annotation; infer one from the concrete flattened arguments.
    annotation = tuple((core.raise_to_shaped(core.get_aval(x)), True)
                       for x in flat_args)
    transposed = lu.annotate(transposed, annotation)
  outs = primitive.bind(transposed, *flat_args, **params)
  return tree_unflatten(out_tree(), outs)
def lower_xla_callable(fun: lu.WrappedFun, device, backend, name,
                       donated_invars, always_lower: bool, keep_unused: bool,
                       *arg_specs):
  """Lower into XLA.

  Args:
    always_lower: If `True`, even trivial programs (not doing any computation
      such as lambda x: x) will be lowered into an XLA program.
    keep_unused: If `False` (the default), arguments that JAX determines to be
      unused by `fun` *may* be dropped from resulting compiled XLA executables.
      Such arguments will not be transferred to the device nor provided to the
      underlying executable. If `True`, unused arguments will not be pruned.
  """
  if device is not None and backend is not None:
    raise ValueError("can't specify both a device and a backend for jit, "
                     "got device={} and backend={}".format(device, backend))
  abstract_args, arg_devices = util.unzip2(arg_specs)
  if fun.in_type is None:
    # Add an annotation inferred from the arguments; no dynamic axes here.
    in_type = tuple(unsafe_zip(abstract_args, itertools.repeat(True)))
    fun = lu.annotate(fun, in_type)
  else:
    # An existing annotation must agree with the explicit abstract arguments.
    assert all(map(core.typematch, abstract_args,
                   [a for a, k in fun.in_type if k]))
  # NOTE: "{elapsed_time}" is deliberately not an f-string field; it is a
  # placeholder substituted by log_elapsed_time itself.
  with log_elapsed_time(f"Finished tracing + transforming {fun.__name__} "
                        "for jit in {elapsed_time} sec"):
    jaxpr, out_type, consts = pe.trace_to_jaxpr_final2(
        fun, pe.debug_info_final(fun, "jit"))
  out_avals, kept_outputs = util.unzip2(out_type)
  if any(isinstance(c, core.Tracer) for c in consts):
    raise UnexpectedTracerError("Encountered an unexpected tracer.")

  if config.jax_dynamic_shapes:
    # Input pruning is skipped under dynamic shapes.
    keep_unused = True
  else:
    jaxpr = apply_outfeed_rewriter(jaxpr)

  if not keep_unused:
    # Drop inputs the jaxpr does not use, plus the corresponding consts, arg
    # specs, and donation flags, so unused arguments never reach the device.
    jaxpr, kept_const_idx, kept_var_idx = _prune_unused_inputs(jaxpr)
    consts = [c for i, c in enumerate(consts) if i in kept_const_idx]
    abstract_args, arg_devices = util.unzip2(
        [a for i, a in enumerate(arg_specs) if i in kept_var_idx])
    donated_invars = [x for i, x in enumerate(donated_invars)
                      if i in kept_var_idx]
    del kept_const_idx
  else:
    kept_var_idx = set(range(len(abstract_args)))

  nreps = jaxpr_replicas(jaxpr)
  device = _xla_callable_device(nreps, backend, device, arg_devices)
  backend = xb.get_device_backend(device) if device else xb.get_backend(backend)

  if (config.jax_dynamic_shapes and jaxpr_has_bints(jaxpr) and
      not _backend_supports_unbounded_dynamic_shapes(backend)):
    # Pad bounded-dynamic shapes to static bounds for backends lacking native
    # unbounded-dynamic-shape support.
    jaxpr, consts = pe.pad_jaxpr(jaxpr, consts)

  # NOTE(review): assumes `map` here is the module's eager safe_map alias;
  # with builtin (lazy) map this prefetch would be a no-op — confirm.
  map(prefetch, itertools.chain(consts, jaxpr_literals(jaxpr)))

  # Computations that only produce constants and/or only rearrange their
  # inputs, which are often produced from partial evaluation, don't need
  # compilation, and don't need to evaluate their arguments.
  if not jaxpr.eqns and not always_lower and all(kept_outputs):
    return XlaComputation(
        name, None, True, None, None, None, jaxpr=jaxpr, consts=consts,
        device=device, in_avals=abstract_args, out_avals=out_avals,
        has_unordered_effects=False, ordered_effects=[],
        kept_var_idx=kept_var_idx, keepalive=None)

  if not _on_exit:
    log_priority = logging.WARNING if config.jax_log_compiles else logging.DEBUG
    if len(abstract_args) > 10:
      msg = f"Compiling {fun.__name__} ({id(fun)}) for {len(abstract_args)} args."
    else:
      # Fixed: the message previously read "({id(fun)} for args ..." with an
      # unbalanced parenthesis; closed to match the long-arg-list branch above.
      msg = f"Compiling {fun.__name__} ({id(fun)}) for args {abstract_args}."
    logging.log(log_priority, msg)

  if nreps > 1:
    warnings.warn(
        f"The jitted function {name} includes a pmap. Using "
        "jit-of-pmap can lead to inefficient data movement, as the outer jit "
        "does not preserve sharded data representations and instead collects "
        "input and output arrays onto a single device. "
        "Consider removing the outer jit unless you know what you're doing. "
        "See https://github.com/google/jax/issues/2926.")

  if nreps > xb.device_count(backend):
    raise ValueError(
        f"compiling computation `{name}` that requires {nreps} replicas, but "
        f"only {xb.device_count(backend)} XLA devices are available.")

  if xb.process_count() > 1 and (nreps > 1 or jaxpr_has_pmap(jaxpr)):
    raise NotImplementedError(
        "jit of multi-host pmap not implemented (and jit-of-pmap can cause "
        "extra data movement anyway, so maybe you don't want it after all).")

  # pass long arg lists as tuple for TPU
  tuple_args = len(abstract_args) > 100
  axis_env = xla.AxisEnv(nreps, (), ())
  name_stack = util.new_name_stack(util.wrap_name(name, 'jit'))
  closed_jaxpr = core.ClosedJaxpr(jaxpr, consts)
  module_name = f"jit_{fun.__name__}"
  # Partition the computation's effects by whether their relative ordering
  # must be preserved across executions.
  unordered_effects = [eff for eff in closed_jaxpr.effects
                       if eff not in core.ordered_effects]
  ordered_effects = [eff for eff in closed_jaxpr.effects
                     if eff in core.ordered_effects]
  module, keepalive = mlir.lower_jaxpr_to_module(
      module_name, closed_jaxpr, unordered_effects, ordered_effects,
      backend.platform, mlir.ReplicaAxisContext(axis_env), name_stack,
      donated_invars)
  return XlaComputation(
      name, module, False, donated_invars, fun.in_type, out_type, nreps=nreps,
      device=device, backend=backend, tuple_args=tuple_args,
      in_avals=abstract_args, out_avals=out_avals,
      has_unordered_effects=bool(unordered_effects),
      ordered_effects=ordered_effects, kept_var_idx=kept_var_idx,
      keepalive=keepalive)