def _matchaxis_symbolic_zeros(sz, name, src, dst, x, sum_match=False):
    """Move the batch axis of ``x`` from ``src`` to ``dst``, symbolic zeros included.

    Values that are not ``Zero`` instances are forwarded unchanged to
    ``matchaxis``. For a ``Zero`` only its abstract value is rewritten, via
    ``core.mapped_aval`` / ``core.unmapped_aval``, and a new ``Zero`` is
    returned.

    Args:
      sz: axis size handed to ``core.mapped_aval`` / ``core.unmapped_aval``.
      name: axis name handed to ``core.unmapped_aval``.
      src: current batch axis of ``x`` (an int or ``not_mapped``).
      dst: requested batch axis (an int or ``not_mapped``).
      x: the value or symbolic ``Zero`` to adjust.
      sum_match: when true, a mapped ``Zero`` may be matched to an unmapped
        destination by dropping the mapped axis from its aval.

    Raises:
      ValueError: if ``x`` is a ``Zero`` and no rule covers ``(src, dst)``.
    """
    if not isinstance(x, Zero):
        return matchaxis(sz, src, dst, x, sum_match=sum_match)

    # Already in the requested layout.
    if src == dst:
        return x

    # Both axes are concrete ints: strip the axis at src, re-insert at dst.
    if type(src) == type(dst) == int:
        aval = core.mapped_aval(sz, src, x.aval)
        return Zero(core.unmapped_aval(sz, name, dst, aval))

    # Unbatched zero that must gain a batch axis.
    if src is not_mapped and dst is not not_mapped:
        return Zero(core.unmapped_aval(sz, name, dst, x.aval))

    # Batched zero matched against an unbatched destination.
    if dst is not_mapped and sum_match:
        return Zero(core.mapped_aval(sz, src, x.aval))

    raise ValueError((x, src, dst))
def _flatten_bwd(in_tree, in_avals, out_trees, *args):
    """Adapt a pytree-level custom VJP backward rule to flat inputs and outputs.

    This is a two-step generator. It first yields ``((py_res, py_cts_out), {})``,
    the positional and keyword arguments for the wrapped rule, and receives
    the rule's pytree result back through ``send``. It then yields a flat list
    of input cotangents, one per entry of ``in_avals``.

    Args:
      in_tree: pytree structure of the primal function's argument tuple.
      in_avals: abstract values of the flattened primal inputs.
      out_trees: zero-argument callable returning ``(out_tree, res_tree)``.
      *args: flat residuals followed by flat output cotangents.

    Raises:
      TypeError: if the rule's result is not a tuple whose structure is
        compatible with ``in_tree``.
    """
    out_tree, res_tree = out_trees()
    assert len(args) == res_tree.num_leaves + out_tree.num_leaves
    res, cts_out = split_list(args, [res_tree.num_leaves])
    py_cts_in = yield (tree_unflatten(res_tree, res),
                       tree_unflatten(out_tree, cts_out)), {}

    # A None in the rule's result means "no cotangent for this argument".
    # Each such None is swapped for a non-pytree sentinel, which is then
    # repeated once per leaf of the matching subtree of in_tree so the flat
    # list stays aligned with in_avals.
    zero = object()  # non-pytree stand-in for a None cotangent
    dummy = tree_unflatten(in_tree, [object()] * in_tree.num_leaves)
    cts_in_flat = []

    def append(ct, dummy_subtree):
        # Record `ct` once for every leaf in the corresponding input subtree.
        leaf_count = len(tree_flatten(dummy_subtree)[0])
        cts_in_flat.extend([ct] * leaf_count)
        return ct

    try:
        if not isinstance(py_cts_in, tuple):
            raise ValueError
        tree_map(append,
                 tuple(zero if ct is None else ct for ct in py_cts_in),
                 dummy)
    except ValueError:
        _, in_tree2 = tree_flatten(py_cts_in)
        msg = ("Custom VJP rule must produce an output with the same container "
               "(pytree) structure as the args tuple of the primal function, "
               "and in particular must produce a tuple of length equal to the "
               "number of arguments to the primal function, but got VJP output "
               "structure {} for primal input structure {}.")
        raise TypeError(msg.format(in_tree2, in_tree)) from None

    def as_cotangent(aval, ct):
        # Missing cotangents, and inputs whose aval differs from its
        # at_least_vspace() form (i.e. float0s), become symbolic zeros.
        # TODO(mattjj): change this to check if tangent type represents 0dim vspace
        if ct is zero or aval != aval.at_least_vspace():
            return Zero(aval.at_least_vspace())
        return ct

    yield [as_cotangent(aval, ct) for aval, ct in zip(in_avals, cts_in_flat)]
def unmap_zero(zero, in_axis):
    """Return ``zero`` with the mapped axis restored to its abstract value.

    When ``in_axis`` is None the zero is returned untouched. Otherwise a new
    ``Zero`` is built whose aval is ``zero.aval`` unmapped at ``in_axis``,
    using the ``axis_size`` and ``axis_name`` entries of ``params`` (a name
    taken from the surrounding scope, which is not shown here).
    """
    if in_axis is None:
        return zero
    unmapped_aval = core.unmapped_aval(
        params['axis_size'], params['axis_name'], in_axis, zero.aval)
    return Zero(unmapped_aval)
def lift(self, val):
    """Wrap ``val`` in a ``JVPTracer`` on this trace with a symbolic-zero tangent.

    The tangent is a ``Zero`` whose aval is ``get_aval(val).at_least_vspace()``.
    """
    zero_tangent_aval = get_aval(val).at_least_vspace()
    return JVPTracer(self, val, Zero(zero_tangent_aval))
def read_cotangent(v):
    """Remove and return the cotangent stored for ``v`` in ``ct_env``.

    Falls back to a symbolic ``Zero`` of ``v.aval`` when nothing was stored.
    Because the entry is popped, a second read of the same variable yields
    the ``Zero`` fallback.
    """
    fallback = Zero(v.aval)
    return ct_env.pop(v, fallback)
def backward_pass(jaxpr: core.Jaxpr, reduce_axes, transform_stack,
                  consts, primals_in, cotangents_in):
    """Walk ``jaxpr`` in reverse, turning output cotangents into input cotangents.

    Args:
      jaxpr: the jaxpr to transpose; its ``constvars``, ``invars``,
        ``outvars`` and ``eqns`` are read.
      reduce_axes: axis names. A cotangent whose named shape contains one of
        these names while the target variable's aval does not is summed over
        those axes with ``jax.lax.psum`` before being accumulated.
      transform_stack: if truthy, the pass runs inside
        ``source_info_util.transform_name_stack('transpose')``; otherwise
        inside a null context.
      consts: values bound to ``jaxpr.constvars``.
      primals_in: values bound to ``jaxpr.invars``; entries for which
        ``is_undefined_primal`` is true are not recorded.
      cotangents_in: cotangents for ``jaxpr.outvars``.

    Returns:
      The result of mapping the cotangent lookup over ``jaxpr.invars``;
      variables that accumulated nothing get ``Zero(v.aval)``.

    NOTE(review): the ``map(...)`` calls below are used for their side
    effects, so ``map`` is presumably rebound to an eager variant at module
    level. Verify before replacing them with loops.
    """
    # With only symbolic-zero cotangents there is nothing to propagate.
    if all(type(ct) is Zero for ct in cotangents_in):
        return map(lambda v: Zero(v.aval), jaxpr.invars)

    primal_env: Dict[Any, Any] = {}
    ct_env: Dict[Any, Any] = {}

    def write_cotangent(prim, v, ct):
        # Guards against passing the Zero class itself instead of an instance
        # (an old harmless type error).
        assert ct is not Zero, (prim, v.aval)
        # Nothing is stored for absent cotangents, literals or symbolic zeros.
        # FIXME: asserting v.aval == ct.aval for Zero cotangents triggers a
        # lot of failures, so that check stays disabled.
        if ct is None or type(v) is Literal or type(ct) is Zero:
            return
        axes_to_reduce = tuple(
            name for name in reduce_axes
            if name in core.get_aval(ct).named_shape
            and name not in v.aval.named_shape)
        if axes_to_reduce:
            ct = jax.lax.psum(ct, axis_name=axes_to_reduce)
        # Accumulate when a contribution for v already exists.
        if v in ct_env:
            ct_env[v] = add_tangents(ct_env[v], ct)
        else:
            ct_env[v] = ct
        if config.jax_enable_checks:
            ct_aval = core.get_aval(ct_env[v])
            joined_aval = (core.lattice_join(v.aval, ct_aval)
                           .strip_weak_type().strip_named_shape())
            expected_aval = v.aval.strip_weak_type().strip_named_shape()
            assert expected_aval == joined_aval, (prim, v.aval, ct_aval)

    def read_cotangent(v):
        # Popping means each stored cotangent is consumed at most once.
        return ct_env.pop(v, Zero(v.aval))

    def read_primal(v):
        if type(v) is Literal:
            return v.val
        return primal_env.get(v, UndefinedPrimal(v.aval))

    def write_primal(v, val):
        if is_undefined_primal(val):
            return
        primal_env[v] = val

    write_primal(core.unitvar, core.unit)
    map(write_primal, jaxpr.constvars, consts)
    # FIXME: invars can contain both primal and tangent values, and this line
    # forces primal_in to contain UndefinedPrimals for tangent values!
    map(write_primal, jaxpr.invars, primals_in)

    if transform_stack:
        ctx = source_info_util.transform_name_stack('transpose')
    else:
        ctx = contextlib.nullcontext()

    with ctx:
        map(partial(write_cotangent, 'outvars'), jaxpr.outvars, cotangents_in)
        for eqn in jaxpr.eqns[::-1]:
            primitive = eqn.primitive
            # FIXME: Some invars correspond to tangents
            invals = map(read_primal, eqn.invars)
            if primitive.multiple_results:
                cts_in = map(read_cotangent, eqn.outvars)
            else:
                cts_in, = map(read_cotangent, eqn.outvars)
            name_stack = (source_info_util.current_name_stack()
                          + eqn.source_info.name_stack)
            with source_info_util.user_context(eqn.source_info.traceback,
                                               name_stack=name_stack):
                if primitive.call_primitive or primitive.map_primitive:
                    cts_in_avals = [v.aval for v in eqn.outvars]
                    params = dict(eqn.params)
                    call_jaxpr = params.pop('call_jaxpr')
                    transpose_rule = get_primitive_transpose(primitive)
                    cts_out = transpose_rule(
                        params, call_jaxpr, invals, cts_in, cts_in_avals,
                        reduce_axes)
                elif primitive in reducing_transposes:
                    cts_out = reducing_transposes[primitive](
                        reduce_axes, cts_in, *invals, **eqn.params)
                else:
                    cts_out = get_primitive_transpose(primitive)(
                        cts_in, *invals, **eqn.params)
                # A rule may return the Zero class to mean "all zeros".
                if cts_out is Zero:
                    cts_out = [Zero(v.aval) for v in eqn.invars]
                # NOTE(review): write-back is kept inside the per-equation
                # source-info context; confirm this matches the intended
                # nesting.
                # FIXME: Some invars correspond to primals!
                map(partial(write_cotangent, primitive), eqn.invars, cts_out)

    return map(read_cotangent, jaxpr.invars)
def recast_to_float0(primal, tangent):
    """Swap ``tangent`` for a symbolic zero when the primal's tangent dtype is float0.

    If ``core.primal_dtype_to_tangent_dtype(dtype(primal))`` equals ``float0``,
    the result is ``Zero(get_aval(primal).at_least_vspace())``. Otherwise
    ``tangent`` is returned unchanged.
    """
    tangent_dtype = core.primal_dtype_to_tangent_dtype(dtype(primal))
    if tangent_dtype == float0:
        return Zero(get_aval(primal).at_least_vspace())
    return tangent
def backward_pass(jaxpr: core.Jaxpr, consts, primals_in, cotangents_in):
    """Walk ``jaxpr`` in reverse, turning output cotangents into input cotangents.

    Args:
      jaxpr: the jaxpr to transpose; its ``constvars``, ``invars``,
        ``outvars`` and ``eqns`` are read.
      consts: values bound to ``jaxpr.constvars``.
      primals_in: values bound to ``jaxpr.invars``; entries for which
        ``is_undefined_primal`` is true are not recorded.
      cotangents_in: cotangents for ``jaxpr.outvars``.

    Returns:
      The result of mapping the cotangent lookup over ``jaxpr.invars``;
      variables that accumulated nothing get ``Zero(v.aval)``.

    NOTE(review): the ``map(...)`` calls below are used for their side
    effects, so ``map`` is presumably rebound to an eager variant at module
    level. Verify before replacing them with loops.
    """
    # With only symbolic-zero cotangents there is nothing to propagate.
    if all(type(ct) is Zero for ct in cotangents_in):
        return map(lambda v: Zero(v.aval), jaxpr.invars)

    primal_env: Dict[Any, Any] = {}
    ct_env: Dict[Any, Any] = {}

    def write_cotangent(prim, v, ct):
        # Guards against passing the Zero class itself instead of an instance
        # (an old harmless type error).
        assert ct is not Zero, (prim, v.aval)
        # Nothing is stored for absent cotangents, literals or symbolic zeros.
        # FIXME: asserting v.aval == ct.aval for Zero cotangents triggers a
        # lot of failures, so that check stays disabled.
        if ct is None or type(v) is Literal or type(ct) is Zero:
            return
        # Accumulate when a contribution for v already exists.
        if v in ct_env:
            ct_env[v] = add_tangents(ct_env[v], ct)
        else:
            ct_env[v] = ct
        if config.jax_enable_checks:
            ct_aval = core.get_aval(ct_env[v])
            joined_aval = (core.lattice_join(v.aval, ct_aval)
                           .strip_weak_type().strip_named_shape())
            expected_aval = v.aval.strip_weak_type().strip_named_shape()
            assert expected_aval == joined_aval, (prim, v.aval, ct_aval)

    def read_cotangent(v):
        # Popping means each stored cotangent is consumed at most once.
        return ct_env.pop(v, Zero(v.aval))

    def read_primal(v):
        if type(v) is Literal:
            return v.val
        return primal_env.get(v, UndefinedPrimal(v.aval))

    def write_primal(v, val):
        if is_undefined_primal(val):
            return
        primal_env[v] = val

    write_primal(core.unitvar, core.unit)
    map(write_primal, jaxpr.constvars, consts)
    # FIXME: invars can contain both primal and tangent values, and this line
    # forces primal_in to contain UndefinedPrimals for tangent values!
    map(write_primal, jaxpr.invars, primals_in)

    map(partial(write_cotangent, 'outvars'), jaxpr.outvars, cotangents_in)
    for eqn in jaxpr.eqns[::-1]:
        primitive = eqn.primitive
        # FIXME: Some invars correspond to tangents
        invals = map(read_primal, eqn.invars)
        if primitive.multiple_results:
            cts_in = map(read_cotangent, eqn.outvars)
        else:
            cts_in, = map(read_cotangent, eqn.outvars)
        with source_info_util.user_context(eqn.source_info):
            if primitive.call_primitive or primitive.map_primitive:
                cts_in_avals = [v.aval for v in eqn.outvars]
                call_jaxpr, params = core.extract_call_jaxpr(
                    primitive, eqn.params)
                transpose_rule = get_primitive_transpose(primitive)
                cts_out = transpose_rule(
                    params, call_jaxpr, invals, cts_in, cts_in_avals)
            else:
                cts_out = get_primitive_transpose(primitive)(
                    cts_in, *invals, **eqn.params)
        # NOTE(review): normalisation and write-back are placed after the
        # source-info context closes; confirm this matches the intended
        # nesting.
        # A rule may return the Zero class to mean "all zeros".
        if cts_out is Zero:
            cts_out = [Zero(v.aval) for v in eqn.invars]
        # FIXME: Some invars correspond to primals!
        map(partial(write_cotangent, primitive), eqn.invars, cts_out)

    return map(read_cotangent, jaxpr.invars)