def jaxpr_subcomp(ctx: TranslationContext, jaxpr: core.Jaxpr,
                  consts: Sequence[XlaOp], *args: XlaOp) -> Sequence[XlaOp]:
  """Translates every equation of `jaxpr`, in order, into ops on `ctx.builder`.

  Args:
    ctx: translation context; `ctx.platform` must be set (asserted below) and
      selects backend-specific translation rules before the generic ones.
    jaxpr: the jaxpr whose equations are translated.
    consts: XLA values bound (via `_partitionmap`) to `jaxpr.constvars`.
    *args: XLA values bound (via `_partitionmap`) to `jaxpr.invars`.

  Returns:
    The XLA values bound to `jaxpr.outvars`, flattened with `_flatmap`.

  Raises:
    NotImplementedError: if neither a backend-specific nor a generic
      translation rule is registered for an equation's primitive.
  """
  assert ctx.platform is not None
  value_env: Dict[core.Var, Sequence[XlaOp]] = {}

  def store(var, node):
    assert node is not None
    value_env[var] = node

  def load(var):
    # Literals are not stored in the environment; a constant is emitted on the
    # builder every time one is read.
    if type(var) is not Literal:
      return value_env[var]
    return pyval_to_ir_constants(ctx.builder, canonicalize_dtype(var.val))

  def abstract_value(var):
    return abstractify(var.val) if type(var) is Literal else var.aval

  _partitionmap(store, [core.unitvar],
                pyval_to_ir_constants(ctx.builder, core.unit))
  _partitionmap(store, jaxpr.constvars, consts)
  _partitionmap(store, jaxpr.invars, args)

  for eqn in jaxpr.eqns:
    # With the experimental name stack enabled, the equation's own name stack
    # is appended to the context's so that metadata reflects the full path.
    if not config.jax_experimental_name_stack:
      eqn_source_info = eqn.source_info
    else:
      assert isinstance(ctx.name_stack, source_info_util.NameStack), type(ctx.name_stack)
      eqn_source_info = eqn.source_info.replace(
          name_stack=ctx.name_stack + eqn.source_info.name_stack)

    # Metadata stays attached to every op emitted until clear_op_metadata().
    ctx.builder.set_op_metadata(make_op_metadata(
        eqn.primitive, eqn.params,
        name_stack=ctx.name_stack,
        source_info=eqn_source_info))
    in_nodes = _flatmap(load, eqn.invars)

    # Backend-specific rules take precedence over the generic table.
    backend_rules = (_backend_specific_translations[ctx.platform]
                     if ctx.platform is not None else {})
    if eqn.primitive in backend_rules:
      rule = backend_rules[eqn.primitive]
    elif eqn.primitive in _translations:
      rule = _translations[eqn.primitive]
    else:
      raise NotImplementedError(
          f"XLA translation rule for primitive '{eqn.primitive.name}' not found")

    with source_info_util.user_context(eqn.source_info.traceback):
      if config.jax_experimental_name_stack:
        eqn_ctx = ctx.replace(name_stack=eqn_source_info.name_stack)
      else:
        eqn_ctx = ctx
      ans = rule(eqn_ctx,
                 map(abstract_value, eqn.invars),
                 map(abstract_value, eqn.outvars),
                 *in_nodes, **eqn.params)

    assert isinstance(ans, collections.abc.Sequence), (ans, eqn)
    assert all(isinstance(x, xe.XlaOp) for x in ans), (ans, eqn)
    map(ctx.builder.get_shape, ans)  # force xla to do shape error checking
    ctx.builder.clear_op_metadata()
    _partitionmap(store, eqn.outvars, ans)

  return _flatmap(load, jaxpr.outvars)
def backward_pass(jaxpr: core.Jaxpr, consts, primals_in, cotangents_in):
  """Transposes `jaxpr`, pushing cotangents from its outputs back to its inputs.

  Equations are visited in reverse order. For each one, the cotangents of its
  outvars are read, the primitive's transpose rule is applied, and the
  resulting cotangents are accumulated onto its invars (with `add_tangents`
  when a var receives more than one contribution).

  Args:
    jaxpr: the (linear) jaxpr to transpose.
    consts: primal values bound to `jaxpr.constvars`.
    primals_in: primal values for `jaxpr.invars`; entries for which
      `is_undefined_primal` is true are not recorded, so reading them yields
      `UndefinedPrimal(v.aval)`.
    cotangents_in: one cotangent per `jaxpr.outvars` entry.

  Returns:
    One cotangent per `jaxpr.invars` entry; inputs that never received a
    cotangent get `Zero(v.aval)`. If every entry of `cotangents_in` is a
    `Zero`, this is returned immediately without walking the equations.
  """
  if all(type(ct) is Zero for ct in cotangents_in):
    return map(lambda v: Zero(v.aval), jaxpr.invars)

  def write_cotangent(prim, v, ct):
    # Accumulate cotangent `ct` onto var `v`; `prim` is only used in assertion
    # messages. None cotangents, literal targets and symbolic Zeros are skipped.
    # assert v not in primal_env
    assert ct is not Zero, (prim, v.aval )  # check for an old harmless type error
    if ct is None or type(v) is Literal:
      return
    if type(ct) is Zero:
      # FIXME: This triggers a lot of failures!
      # assert v.aval == ct.aval, (prim, v.aval, ct.aval)
      return
    ct_env[v] = add_tangents(ct_env[v], ct) if v in ct_env else ct
    if not core.skip_checks:
      # The accumulated cotangent's aval must join (weak type ignored) to the
      # var's own aval.
      ct_aval = core.get_aval(ct_env[v])
      joined_aval = core.lattice_join(v.aval, ct_aval).strip_weak_type()
      assert v.aval.strip_weak_type() == joined_aval, (prim, v.aval, ct_aval)

  def read_cotangent(v):
    # Non-destructive read; entries are freed explicitly via `drop_cts` below.
    return ct_env.get(v, Zero(v.aval))

  def read_primal(v):
    if type(v) is Literal:
      return v.val
    else:
      return primal_env.get(v, UndefinedPrimal(v.aval))

  def write_primal(v, val):
    if not is_undefined_primal(val):
      primal_env[v] = val

  primal_env: Dict[Any, Any] = {}
  write_primal(core.unitvar, core.unit)
  map(write_primal, jaxpr.constvars, consts)
  # FIXME: invars can contain both primal and tangent values, and this line
  #        forces primal_in to contain UndefinedPrimals for tangent values!
  map(write_primal, jaxpr.invars, primals_in)

  # Find the last use of each cotangent so that they can be removed
  # as soon as possible.
  # drop_cts[i] holds the outvars of eqns[i] not seen earlier in forward order.
  # Because transposition runs in reverse, once eqns[i] has been transposed every
  # consumer of those outvars has already contributed its cotangent (assuming
  # each var is defined by a single eqn), so their ct_env entries can be popped.
  drop_cts: List[Set[Any]] = []
  seen_vars: Set[Any] = set(jaxpr.invars)
  for eqn in jaxpr.eqns:
    read_set = set(eqn.outvars)  # NOTE: eqn is not transposed yet!
    drop_cts.append(read_set - seen_vars)
    seen_vars |= read_set

  ct_env: Dict[Any, Any] = {}
  map(partial(write_cotangent, 'outvars'), jaxpr.outvars, cotangents_in)
  for eqn, to_drop in zip(jaxpr.eqns[::-1], drop_cts[::-1]):
    # FIXME: Some invars correspond to tangents
    invals = map(read_primal, eqn.invars)
    if eqn.primitive.multiple_results:
      cts_in = map(read_cotangent, eqn.outvars)
    else:
      cts_in, = map(read_cotangent, eqn.outvars)
    with source_info_util.user_context(eqn.source_info):
      if eqn.primitive.call_primitive or eqn.primitive.map_primitive:
        # Call/map primitives get their sub-jaxpr and params split out, plus
        # the avals of the incoming cotangents.
        cts_in_avals = [v.aval for v in eqn.outvars]
        call_jaxpr, params = core.extract_call_jaxpr(
            eqn.primitive, eqn.params)
        cts_out = get_primitive_transpose(eqn.primitive)(
            params, call_jaxpr, invals, cts_in, cts_in_avals)
      else:
        cts_out = get_primitive_transpose(eqn.primitive)(cts_in, *invals,
                                                         **eqn.params)
    # A transpose rule may return the `Zero` class itself as shorthand for
    # "zero cotangent for every input"; expand it per invar.
    cts_out = [Zero(v.aval) for v in eqn.invars] if cts_out is Zero else cts_out
    # FIXME: Some invars correspond to primals!
    map(partial(write_cotangent, eqn.primitive), eqn.invars, cts_out)
    for var in to_drop:
      ct_env.pop(var, None)  # NB: Constant cotangents might be missing

  cotangents_out = map(read_cotangent, jaxpr.invars)
  return cotangents_out
def backward_pass(jaxpr: core.Jaxpr, consts, primals_in, cotangents_in):
  """Transposes `jaxpr`, pushing cotangents from its outputs back to its inputs.

  Args:
    jaxpr: the (linear) jaxpr to transpose.
    consts: primal values bound to `jaxpr.constvars`.
    primals_in: primal values for `jaxpr.invars`; undefined primals are not
      recorded and read back as `UndefinedPrimal(v.aval)`.
    cotangents_in: one cotangent per `jaxpr.outvars` entry.

  Returns:
    One cotangent per `jaxpr.invars` entry (`Zero(v.aval)` when none was
    accumulated). Short-circuits to all-`Zero` when every incoming cotangent is
    a `Zero`.
  """
  if all(type(ct) is Zero for ct in cotangents_in):
    return map(lambda v: Zero(v.aval), jaxpr.invars)

  primal_env: Dict[Any, Any] = {}
  ct_env: Dict[Any, Any] = {}

  def accumulate_ct(prim, var, ct):
    # assert var not in primal_env
    assert ct is not Zero, (prim, var.aval )  # check for an old harmless type error
    # Nothing to record for missing cotangents, literal targets, or symbolic
    # zeros. (FIXME: asserting var.aval == ct.aval for zeros triggers a lot of
    # failures!)
    if ct is None or type(var) is Literal or type(ct) is Zero:
      return
    if var in ct_env:
      ct_env[var] = add_tangents(ct_env[var], ct)
    else:
      ct_env[var] = ct
    if not config.jax_enable_checks:
      return
    ct_aval = core.get_aval(ct_env[var])
    joined_aval = core.lattice_join(var.aval, ct_aval).strip_weak_type()
    assert var.aval.strip_weak_type() == joined_aval, (prim, var.aval, ct_aval)

  def take_ct(var):
    # Reading a cotangent removes it from the environment.
    return ct_env.pop(var, Zero(var.aval))

  def load_primal(var):
    return (var.val if type(var) is Literal
            else primal_env.get(var, UndefinedPrimal(var.aval)))

  def store_primal(var, val):
    if not is_undefined_primal(val):
      primal_env[var] = val

  store_primal(core.unitvar, core.unit)
  map(store_primal, jaxpr.constvars, consts)
  # FIXME: invars can contain both primal and tangent values, and this line
  #        forces primal_in to contain UndefinedPrimals for tangent values!
  map(store_primal, jaxpr.invars, primals_in)

  map(partial(accumulate_ct, 'outvars'), jaxpr.outvars, cotangents_in)
  for eqn in reversed(jaxpr.eqns):
    prim = eqn.primitive
    # FIXME: Some invars correspond to tangents
    invals = map(load_primal, eqn.invars)
    cts_in = map(take_ct, eqn.outvars)
    if not prim.multiple_results:
      cts_in, = cts_in
    with source_info_util.user_context(eqn.source_info):
      if prim.call_primitive or prim.map_primitive:
        out_avals = [v.aval for v in eqn.outvars]
        call_jaxpr, params = core.extract_call_jaxpr(prim, eqn.params)
        transpose = get_primitive_transpose(prim)
        cts_out = transpose(params, call_jaxpr, invals, cts_in, out_avals)
      else:
        transpose = get_primitive_transpose(prim)
        cts_out = transpose(cts_in, *invals, **eqn.params)
    # The `Zero` class itself is shorthand for "zero cotangent for every input".
    if cts_out is Zero:
      cts_out = [Zero(v.aval) for v in eqn.invars]
    # FIXME: Some invars correspond to primals!
    map(partial(accumulate_ct, prim), eqn.invars, cts_out)

  return map(take_ct, jaxpr.invars)
def backward_pass(jaxpr: core.Jaxpr, reduce_axes, transform_stack,
                  consts, primals_in, cotangents_in):
  """Transposes `jaxpr`, pushing cotangents from its outputs back to its inputs.

  Equations are visited in reverse order; each primitive's transpose rule maps
  the cotangents of its outvars to cotangents of its invars, which are
  accumulated with `add_tangents`.

  Args:
    jaxpr: the (linear) jaxpr to transpose.
    reduce_axes: named axes to sum over. When a written cotangent's aval has one
      of these in its `named_shape` but the target var's aval does not, the
      cotangent is `psum`-reduced over it. Also forwarded to call/map transposes
      and to rules registered in `reducing_transposes`.
    transform_stack: if truthy, the transposition runs inside
      `source_info_util.transform_name_stack('transpose')`.
    consts: primal values bound to `jaxpr.constvars`.
    primals_in: primal values for `jaxpr.invars`; undefined primals are not
      recorded and read back as `UndefinedPrimal(v.aval)`.
    cotangents_in: one cotangent per `jaxpr.outvars` entry.

  Returns:
    One cotangent per `jaxpr.invars` entry (`Zero(v.aval)` when none was
    accumulated). Short-circuits to all-`Zero` when every incoming cotangent is
    a `Zero`.
  """
  if all(type(ct) is Zero for ct in cotangents_in):
    return map(lambda v: Zero(v.aval), jaxpr.invars)

  def write_cotangent(prim, v, ct):
    # Accumulate cotangent `ct` onto var `v`; `prim` is only used in assertion
    # messages. None cotangents, literal targets and symbolic Zeros are skipped.
    # assert v not in primal_env
    assert ct is not Zero, (prim, v.aval)  # check for an old harmless type error
    if ct is None or type(v) is Literal:
      return
    if type(ct) is Zero:
      # FIXME: This triggers a lot of failures!
      # assert v.aval == ct.aval, (prim, v.aval, ct.aval)
      return
    # Named axes present on the cotangent but absent from the target var are
    # summed out so the cotangent matches the var's named shape.
    axes_to_reduce = tuple(axis_name for axis_name in reduce_axes
                           if axis_name in core.get_aval(ct).named_shape
                           and axis_name not in v.aval.named_shape)
    if axes_to_reduce:
      ct = jax.lax.psum(ct, axis_name=axes_to_reduce)
    ct_env[v] = add_tangents(ct_env[v], ct) if v in ct_env else ct
    if config.jax_enable_checks:
      # Compare avals ignoring weak type and named shape.
      ct_aval = core.get_aval(ct_env[v])
      joined_aval = core.lattice_join(v.aval, ct_aval).strip_weak_type().strip_named_shape()
      assert v.aval.strip_weak_type().strip_named_shape() == joined_aval, (prim, v.aval, ct_aval)

  def read_cotangent(v):
    # Reading removes the entry from ct_env (presumably to release the
    # accumulated cotangent as early as possible).
    return ct_env.pop(v, Zero(v.aval))

  def read_primal(v):
    if type(v) is Literal:
      return v.val
    else:
      return primal_env.get(v, UndefinedPrimal(v.aval))

  def write_primal(v, val):
    if not is_undefined_primal(val):
      primal_env[v] = val

  primal_env: Dict[Any, Any] = {}
  write_primal(core.unitvar, core.unit)
  map(write_primal, jaxpr.constvars, consts)
  # FIXME: invars can contain both primal and tangent values, and this line
  #        forces primal_in to contain UndefinedPrimals for tangent values!
  map(write_primal, jaxpr.invars, primals_in)

  ct_env: Dict[Any, Any] = {}
  ctx = (source_info_util.transform_name_stack('transpose') if transform_stack
         else contextlib.nullcontext())
  with ctx:
    map(partial(write_cotangent, 'outvars'), jaxpr.outvars, cotangents_in)
    for eqn in jaxpr.eqns[::-1]:
      # FIXME: Some invars correspond to tangents
      invals = map(read_primal, eqn.invars)
      if eqn.primitive.multiple_results:
        cts_in = map(read_cotangent, eqn.outvars)
      else:
        cts_in, = map(read_cotangent, eqn.outvars)
      # The equation's own name stack is appended to the current (possibly
      # 'transpose'-prefixed) one for the duration of its transpose rule.
      name_stack = source_info_util.current_name_stack() + eqn.source_info.name_stack
      with source_info_util.user_context(eqn.source_info.traceback, name_stack=name_stack):
        # Dispatch: call/map primitives get their 'call_jaxpr' split out of
        # params; rules in `reducing_transposes` additionally receive
        # `reduce_axes`; everything else uses the plain transpose rule.
        if eqn.primitive.call_primitive or eqn.primitive.map_primitive:
          cts_in_avals = [v.aval for v in eqn.outvars]
          params = dict(eqn.params)
          call_jaxpr = params.pop('call_jaxpr')
          cts_out = get_primitive_transpose(eqn.primitive)(
              params, call_jaxpr, invals, cts_in, cts_in_avals, reduce_axes)
        elif eqn.primitive in reducing_transposes:
          cts_out = reducing_transposes[eqn.primitive](
              reduce_axes, cts_in, *invals, **eqn.params)
        else:
          cts_out = get_primitive_transpose(eqn.primitive)(
              cts_in, *invals, **eqn.params)
        # A transpose rule may return the `Zero` class itself as shorthand for
        # "zero cotangent for every input"; expand it per invar.
        cts_out = [Zero(v.aval) for v in eqn.invars] if cts_out is Zero else cts_out
        # FIXME: Some invars correspond to primals!
        map(partial(write_cotangent, eqn.primitive), eqn.invars, cts_out)

  cotangents_out = map(read_cotangent, jaxpr.invars)
  return cotangents_out
def jaxpr_subcomp(ctx: ModuleContext, jaxpr: core.Jaxpr,
                  consts: Sequence[Sequence[ir.Value]],
                  *args: Sequence[ir.Value]) -> Sequence[Sequence[ir.Value]]:
  """Lowers a jaxpr into mHLO, inlined into an existing function.

  Assumes that an MLIR context, location, and insertion point are set.

  Args:
    ctx: module-level lowering context; `ctx.platform` selects
      platform-specific rules and `ctx.name_stack` is attached to op locations.
    jaxpr: the jaxpr to lower.
    consts: one sequence of IR values per `jaxpr.constvars` entry.
    *args: one sequence of IR values per `jaxpr.invars` entry.

  Returns:
    One tuple of IR values per `jaxpr.outvars` entry.

  Raises:
    NotImplementedError: if no MLIR lowering or XLA translation rule is
      registered for an equation's primitive on `ctx.platform`.
    ValueError: if a lowering rule returns a non-iterable.
  """
  def read(v: core.Var) -> Sequence[ir.Value]:
    # Literals are not stored in `env`; fresh constants are emitted per read.
    if type(v) is core.Literal:
      return ir_constants(v.val, canonicalize_types=True)
    else:
      return env[v]

  def aval(v: core.Var) -> core.AbstractValue:
    if type(v) is core.Literal:
      return xla.abstractify(v.val)
    else:
      return v.aval

  def write(v: core.Var, node: Sequence[ir.Value]):
    assert node is not None
    env[v] = tuple(node)

  env: Dict[core.Var, Tuple[ir.Value, ...]] = {}

  assert len(args) == len(jaxpr.invars), (jaxpr, args)
  assert len(consts) == len(jaxpr.constvars), (jaxpr, consts)
  assert all(isinstance(v, ir.Value) for vs in consts for v in vs), consts
  write(core.unitvar, ())
  map(write, jaxpr.constvars, consts)
  map(write, jaxpr.invars, args)
  for eqn in jaxpr.eqns:
    in_nodes = map(read, eqn.invars)
    loc = _source_info_to_location(eqn.primitive, eqn.params, eqn.source_info,
                                   name_stack=ctx.name_stack)
    with source_info_util.user_context(eqn.source_info.traceback), loc:
      # Rule precedence: platform-specific MLIR lowering, platform-specific
      # XLA translation (via the fallback), generic MLIR lowering, generic XLA
      # translation (via the fallback).
      if eqn.primitive in _platform_specific_lowerings[ctx.platform]:
        rule = _platform_specific_lowerings[ctx.platform][eqn.primitive]
      elif eqn.primitive in xla._backend_specific_translations[ctx.platform]:
        rule = xla_fallback_lowering(eqn.primitive)
      elif eqn.primitive in _lowerings:
        rule = _lowerings[eqn.primitive]
      elif eqn.primitive in xla._translations:
        rule = xla_fallback_lowering(eqn.primitive)
      else:
        raise NotImplementedError(
            f"MLIR translation rule for primitive '{eqn.primitive.name}' not "
            f"found for platform {ctx.platform}")

      rule_ctx = LoweringRuleContext(
          module_context=ctx, primitive=eqn.primitive,
          avals_in=map(aval, eqn.invars), avals_out=map(aval, eqn.outvars))
      ans = rule(rule_ctx, *map(_unwrap_singleton_ir_values, in_nodes),
                 **eqn.params)

    try:
      out_nodes = tuple(map(wrap_singleton_ir_values, ans))
    except TypeError as e:
      raise ValueError("Output of translation rule must be iterable: "
                       f"{eqn}, got output {ans}") from e

    assert all(isinstance(v, tuple) for v in out_nodes), (ans, eqn)
    assert all(isinstance(v, ir.Value) for w in out_nodes for v in w), (ans, eqn)
    # Check the materialized tuple, not `ans`: the rule only has to return an
    # iterable, and a one-shot iterator has already been consumed above and
    # has no len().
    assert len(out_nodes) == len(eqn.outvars), (ans, eqn)
    map(write, eqn.outvars, out_nodes)
  return map(read, jaxpr.outvars)
def jaxpr_subcomp(ctx: LoweringContext, jaxpr: core.Jaxpr,
                  consts: Sequence[Sequence[ir.Value]],
                  *args: Sequence[ir.Value]) -> Sequence[Sequence[ir.Value]]:
  """Lowers a jaxpr into mHLO, inlined into an existing function.

  Assumes that an MLIR context, location, and insertion point are set.

  Args:
    ctx: lowering context; `ctx.platform` selects platform-specific rules and
      `ctx` is passed as the first argument to every rule.
    jaxpr: the jaxpr to lower.
    consts: one sequence of IR values per `jaxpr.constvars` entry.
    *args: one sequence of IR values per `jaxpr.invars` entry.

  Returns:
    One tuple of IR values per `jaxpr.outvars` entry.

  Raises:
    NotImplementedError: if no MLIR lowering or XLA translation rule is
      registered for an equation's primitive.
    ValueError: if a lowering rule returns a non-iterable.
  """
  def read(v):
    # Literals are not stored in `env`; fresh constants are emitted per read.
    if type(v) is core.Literal:
      return ir_constants(v.val, canonicalize_types=True)
    else:
      return env[v]

  def aval(v):
    if type(v) is core.Literal:
      return xla.abstractify(v.val)
    else:
      return v.aval

  def write(v, node):
    assert node is not None
    env[v] = tuple(node)

  # Each var maps to a variable-length tuple of IR values.
  env: Dict[core.Var, Tuple[ir.Value, ...]] = {}

  assert len(args) == len(jaxpr.invars), (jaxpr, args)
  assert len(consts) == len(jaxpr.constvars), (jaxpr, consts)
  write(core.unitvar, ())
  map(write, jaxpr.constvars, consts)
  map(write, jaxpr.invars, args)
  for eqn in jaxpr.eqns:
    in_nodes = map(read, eqn.invars)
    # TODO(phawkins): attach the primitive name, parameters, and name stack as
    # metadata.
    loc = _source_info_to_location(eqn.source_info)
    with source_info_util.user_context(eqn.source_info.traceback), loc:
      # Rule precedence: platform-specific MLIR lowering, generic MLIR
      # lowering, then the XLA translation table via the fallback lowering.
      if eqn.primitive in _platform_specific_lowerings[ctx.platform]:
        rule = _platform_specific_lowerings[ctx.platform][eqn.primitive]
      elif eqn.primitive in _lowerings:
        rule = _lowerings[eqn.primitive]
      elif eqn.primitive in xla._translations:
        rule = partial(xla_fallback_lowering, eqn.primitive)
      else:
        raise NotImplementedError(
            f"MLIR translation rule for primitive '{eqn.primitive.name}' not "
            "found")

      ans = rule(ctx, map(aval, eqn.invars), map(aval, eqn.outvars),
                 *map(_unwrap_singleton_ir_values, in_nodes),
                 **eqn.params)

    try:
      out_nodes = tuple(map(wrap_singleton_ir_values, ans))
    except TypeError as e:
      raise ValueError("Output of translation rule must be iterable: "
                       f"{eqn}") from e

    assert all(isinstance(v, tuple) for v in out_nodes), (ans, eqn)
    assert all(isinstance(v, ir.Value) for w in out_nodes for v in w), (ans, eqn)
    # Check the materialized tuple, not `ans`: the rule only has to return an
    # iterable, and a one-shot iterator has already been consumed above and
    # has no len().
    assert len(out_nodes) == len(eqn.outvars), (ans, eqn)
    map(write, eqn.outvars, out_nodes)
  return map(read, jaxpr.outvars)