def _instantiate_zeros(arg, tan):
  """Materialize a symbolic ``ad.Zero`` tangent as concrete zeros.

  Args:
    arg: the primal value that ``tan`` is the tangent of.
    tan: a tangent value, possibly an ``ad.Zero`` instance.

  Returns:
    ``tan`` itself when it is not an ``ad.Zero``. Otherwise, the result of
    ``ad.instantiate_zeros_aval(arg.aval, tan)``. If that path raises
    ``AttributeError`` or ``KeyError``, the result of
    ``ad.zeros_like_jaxval(arg)`` is returned instead.
  """
  if type(tan) is ad.Zero:
    try:
      abstract_val = arg.aval
      return ad.instantiate_zeros_aval(abstract_val, tan)
    except (AttributeError, KeyError):
      # Regular Python values have no ``.aval`` (AttributeError). The
      # KeyError presumably originates inside instantiate_zeros_aval for
      # unsupported aval types -- NOTE(review): confirm; not visible here.
      return ad.zeros_like_jaxval(arg)
  return tan
def _convert_zeros(convert_symbolic, example, tangent):
  """Replace the symbolic ``ad.zero`` tangent, recursing into tangent tuples.

  Args:
    convert_symbolic: when ``tangent`` is ``ad.zero``, its truthiness selects
      the replacement. At an ``ad.TangentTuple`` node it is passed to ``map``
      alongside ``example`` and ``tangent``, so it presumably must be iterable
      there. NOTE(review): a plain bool would fail at that point; confirm
      callers pass a structure mirroring ``tangent``.
    example: a value used to build zeros via ``ad.zeros_like_jaxval``; at a
      tuple node it is iterated in lockstep with ``tangent``.
    tangent: the tangent to convert.

  Returns:
    For ``ad.zero``: ``ad.zeros_like_jaxval(example)`` if
    ``convert_symbolic`` is truthy, else ``core.unit``. For an
    ``ad.TangentTuple``: ``core.pack`` of the element-wise conversions.
    Anything else is returned unchanged.
  """
  if tangent is not ad.zero:
    if type(tangent) is not ad.TangentTuple:
      return tangent
    # Convert each component, then repack into a single packed value.
    converted = map(_convert_zeros, convert_symbolic, example, tangent)
    return core.pack(converted)
  return ad.zeros_like_jaxval(example) if convert_symbolic else core.unit