Esempio n. 1
0
File: ad.py Progetto: jbampton/jax
def _primal_tangent_shapes_match(primal, tangent):
  """Assert that ``tangent`` is shape- and dtype-compatible with ``primal``.

  Tangents whose type is exactly ``Zero`` are exempt and nothing is checked.
  Otherwise both values are raised to shaped abstract values (with
  ``weak_type=False``) and two assertions are made:

  * the primal and tangent shapes are identical, and
  * the tangent dtype equals ``core.primal_dtype_to_tangent_dtype`` of the
    primal dtype.

  Returns:
    None.

  Raises:
    AssertionError: if either check fails (only while assertions are
      enabled); the message is the tuple of mismatching values.
  """
  if type(tangent) is Zero:
    # ``Zero`` tangents are exempt from the compatibility check.
    return
  p_aval = raise_to_shaped(get_aval(primal), weak_type=False)
  t_aval = raise_to_shaped(get_aval(tangent), weak_type=False)
  assert p_aval.shape == t_aval.shape, (p_aval.shape, t_aval.shape)
  want_dtype = core.primal_dtype_to_tangent_dtype(p_aval.dtype)
  assert want_dtype == t_aval.dtype, (want_dtype, t_aval.dtype)
Esempio n. 2
0
 def fix_float0(arg_jax, ct_arg_jax):
     """Return ``ct_arg_jax``, or zeros if its dtype is not the expected one.

     The expected cotangent dtype is ``core.primal_dtype_to_tangent_dtype``
     applied to the dtype of the primal ``arg_jax``. When ``ct_arg_jax``
     already has that dtype it is returned unchanged; otherwise the result is
     ``ad_util.zeros_like_aval`` of a ``core.ShapedArray`` with the primal's
     shape and the expected dtype.
     """
     # ``arg_jax`` may be a scalar, so its dtype comes from
     # ``dtypes.result_type`` and its shape from ``np.shape``.
     expected_ct_dtype = core.primal_dtype_to_tangent_dtype(
         dtypes.result_type(arg_jax))
     if expected_ct_dtype == ct_arg_jax.dtype:
         return ct_arg_jax
     zero_aval = core.ShapedArray(np.shape(arg_jax), expected_ct_dtype)
     return ad_util.zeros_like_aval(zero_aval)
Esempio n. 3
0
File: ad.py Progetto: jbampton/jax
def recast_to_float0(primal, tangent):
  """Replace ``tangent`` with a ``Zero`` when the primal's tangent dtype is float0.

  If ``core.primal_dtype_to_tangent_dtype(dtype(primal))`` equals ``float0``,
  the given ``tangent`` is discarded and ``Zero(get_aval(primal).at_least_vspace())``
  is returned. Otherwise ``tangent`` is returned unchanged.
  """
  tangent_dtype = core.primal_dtype_to_tangent_dtype(dtype(primal))
  return (Zero(get_aval(primal).at_least_vspace())
          if tangent_dtype == float0 else tangent)