def _primal_tangent_shapes_match(primal, tangent):
  """Assert that a non-Zero tangent is shape- and dtype-compatible with primal.

  Symbolic ``Zero`` tangents (checked by exact type) are skipped. Otherwise
  the abstract values of both arguments are taken with ``weak_type=False`` and
  compared: the shapes must be equal, and the tangent dtype must equal
  ``core.primal_dtype_to_tangent_dtype`` of the primal dtype.

  Raises:
    AssertionError: carrying the mismatching pair, when either check fails
      (and assertions are enabled).
  """
  if type(tangent) is Zero:
    return
  p_aval = raise_to_shaped(get_aval(primal), weak_type=False)
  t_aval = raise_to_shaped(get_aval(tangent), weak_type=False)
  assert p_aval.shape == t_aval.shape, (p_aval.shape, t_aval.shape)
  # The tangent dtype is derived from the primal dtype rather than copied,
  # since the two are allowed to differ (e.g. via the tangent-dtype mapping).
  want_dtype = core.primal_dtype_to_tangent_dtype(p_aval.dtype)
  assert want_dtype == t_aval.dtype, (want_dtype, t_aval.dtype)
def fix_float0(arg_jax, ct_arg_jax):
  """Return a cotangent whose dtype agrees with the tangent dtype of arg_jax.

  The expected dtype is ``core.primal_dtype_to_tangent_dtype`` applied to the
  result dtype of ``arg_jax``. If ``ct_arg_jax`` already has that dtype it is
  returned unchanged. Otherwise it is discarded and replaced by
  ``ad_util.zeros_like_aval`` of a ``ShapedArray`` with ``np.shape(arg_jax)``
  and the expected dtype. (Presumably this is for primals whose tangent dtype
  is float0, as the name suggests; confirm against callers.)

  Args:
    arg_jax: the primal value; may be a Python/NumPy scalar.
    ct_arg_jax: the cotangent for ``arg_jax``; must expose ``.dtype``.

  Returns:
    ``ct_arg_jax`` or a freshly built zeros value of the expected dtype.
  """
  # arg_jax may be a scalar without a .dtype, hence dtypes.result_type.
  tangent_dtype = core.primal_dtype_to_tangent_dtype(
      dtypes.result_type(arg_jax))
  needs_recast = tangent_dtype != ct_arg_jax.dtype
  if not needs_recast:
    return ct_arg_jax
  zeros_aval = core.ShapedArray(np.shape(arg_jax), tangent_dtype)
  return ad_util.zeros_like_aval(zeros_aval)
def recast_to_float0(primal, tangent):
  """Swap tangent for a symbolic Zero when primal's tangent dtype is float0.

  Args:
    primal: the primal value whose dtype decides the outcome.
    tangent: the tangent to pass through when no recast is needed.

  Returns:
    ``Zero(get_aval(primal).at_least_vspace())`` if
    ``core.primal_dtype_to_tangent_dtype(dtype(primal))`` equals ``float0``,
    otherwise ``tangent`` unchanged.
  """
  has_float0_tangent = (
      core.primal_dtype_to_tangent_dtype(dtype(primal)) == float0)
  return (Zero(get_aval(primal).at_least_vspace())
          if has_float0_tangent else tangent)