def test_device_put_on_python_scalars(self):
  device = jax.devices()[0]
  int_type = dtypes.canonicalize_dtype(np.int64)
  float_type = dtypes.canonicalize_dtype(np.float64)
  complex_type = dtypes.canonicalize_dtype(np.complex128)

  # int
  res = _cpp_device_put(1, device).to_py()
  self.assertEqual(res, 1)
  self.assertEqual(res.dtype, int_type)
  # We also compare to the Python Jax API, to make sure we have the exact
  # same behavior. When Jax removes the flag and removes this feature, this
  # test will fail.
  self.assertEqual(jnp.asarray(1).dtype, res.dtype)

  # float
  res = _cpp_device_put(1.0, device).to_py()
  self.assertEqual(res, 1.0)
  self.assertEqual(res.dtype, float_type)
  self.assertEqual(jnp.asarray(1.0).dtype, res.dtype)

  # bool
  for bool_value in [True, False]:
    res = _cpp_device_put(bool_value, device).to_py()
    self.assertEqual(res, np.asarray(bool_value))
    self.assertEqual(res.dtype, np.bool_)
    self.assertEqual(jnp.asarray(bool_value).dtype, res.dtype)

  # Complex
  res = _cpp_device_put(1 + 1j, device).to_py()
  self.assertEqual(res, 1 + 1j)
  self.assertEqual(res.dtype, complex_type)
  self.assertEqual(jnp.asarray(1 + 1j).dtype, res.dtype)

def test_ravel_pytree(pytree):
  flat, unravel_fn = ravel_pytree(pytree)
  unravel = unravel_fn(flat)
  tree_flatten(tree_multimap(lambda x, y: assert_allclose(x, y), unravel, pytree))
  assert all(tree_flatten(tree_multimap(
      lambda x, y: canonicalize_dtype(lax.dtype(x)) == canonicalize_dtype(lax.dtype(y)),
      unravel, pytree))[0])

def test_arg_signature_of_value(self):
  """Tests the C++ code-path."""
  jax_enable_x64 = config.x64_enabled

  # 1. Numpy scalar types
  for dtype in _SCALAR_NUMPY_TYPES:
    value = dtype(0)
    signature = jaxlib.jax_jit._ArgSignatureOfValue(value, jax_enable_x64)
    self.assertEqual(signature.dtype, jax.device_put(value).dtype)
    self.assertEqual(signature.shape, ())
    self.assertFalse(signature.weak_type)

  # 2. Numpy arrays
  for dtype in _SCALAR_NUMPY_TYPES:
    value = np.zeros((3, 4), dtype=dtype)
    signature = jaxlib.jax_jit._ArgSignatureOfValue(value, jax_enable_x64)
    self.assertEqual(signature.dtype, jax.device_put(value).dtype)
    self.assertEqual(signature.shape, (3, 4))
    self.assertFalse(signature.weak_type)

  int_type = dtypes.canonicalize_dtype(np.int64)
  float_type = dtypes.canonicalize_dtype(np.float64)
  complex_type = dtypes.canonicalize_dtype(np.complex128)

  # 3. Python scalar types
  # int
  signature = jaxlib.jax_jit._ArgSignatureOfValue(1, jax_enable_x64)
  self.assertEqual(signature.dtype, jax.device_put(1).dtype)
  self.assertEqual(signature.dtype, int_type)
  self.assertEqual(signature.shape, ())
  self.assertTrue(signature.weak_type)

  # float
  signature = jaxlib.jax_jit._ArgSignatureOfValue(1.0, jax_enable_x64)
  self.assertEqual(signature.dtype, jax.device_put(1.0).dtype)
  self.assertEqual(signature.dtype, float_type)
  self.assertEqual(signature.shape, ())
  self.assertTrue(signature.weak_type)

  # bool
  for bool_value in [True, False]:
    signature = jaxlib.jax_jit._ArgSignatureOfValue(bool_value, jax_enable_x64)
    self.assertEqual(signature.dtype, jax.device_put(bool_value).dtype)
    self.assertEqual(signature.dtype, np.bool_)
    self.assertEqual(signature.shape, ())
    self.assertTrue(signature.weak_type)

  # Complex
  signature = jaxlib.jax_jit._ArgSignatureOfValue(1 + 1j, jax_enable_x64)
  self.assertEqual(signature.dtype, jax.device_put(1 + 1j).dtype)
  self.assertEqual(signature.dtype, complex_type)
  self.assertEqual(signature.shape, ())
  self.assertTrue(signature.weak_type)

def assertDtypesMatch(self, x, y, *, canonicalize_dtypes=True):
  """Compares dtypes across JAX and TF dtypes. Overrides super method."""
  def to_numpy_dtype(dt):
    return dt if isinstance(dt, np.dtype) else dt.as_numpy_dtype

  if not config.FLAGS.jax_enable_x64 and canonicalize_dtypes:
    self.assertEqual(
        dtypes.canonicalize_dtype(to_numpy_dtype(jtu._dtype(x))),
        dtypes.canonicalize_dtype(to_numpy_dtype(jtu._dtype(y))))
  else:
    self.assertEqual(
        to_numpy_dtype(jtu._dtype(x)),
        to_numpy_dtype(jtu._dtype(y)))

def init(key, shape, dtype=dtype):
  dtype = dtypes.canonicalize_dtype(dtype)
  if len(shape) not in [3, 4, 5]:
    raise ValueError(
        "Delta orthogonal initializer requires a 3D, 4D or 5D shape.")
  if shape[-1] < shape[-2]:
    raise ValueError("`fan_in` must be less than or equal to `fan_out`.")
  ortho_init = orthogonal(scale=scale, column_axis=column_axis, dtype=dtype)
  ortho_matrix = ortho_init(key, shape[-2:])
  W = jnp.zeros(shape, dtype=dtype)
  if len(shape) == 3:
    k = shape[0]
    return ops.index_update(W, ops.index[(k - 1) // 2, ...], ortho_matrix)
  elif len(shape) == 4:
    k1, k2 = shape[:2]
    return ops.index_update(
        W, ops.index[(k1 - 1) // 2, (k2 - 1) // 2, ...], ortho_matrix)
  else:
    k1, k2, k3 = shape[:3]
    return ops.index_update(
        W, ops.index[(k1 - 1) // 2, (k2 - 1) // 2, (k3 - 1) // 2, ...],
        ortho_matrix)

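# A hedged usage sketch for the closure above (assuming it is the `init` returned by
# jax.nn.initializers.delta_orthogonal, which this fragment appears to come from):
# only the spatially central slice of the kernel is non-zero, and it holds an
# orthogonal matrix.
import jax
import jax.numpy as jnp

init = jax.nn.initializers.delta_orthogonal()
W = init(jax.random.PRNGKey(0), (3, 4, 4))      # 3D shape: (width, fan_in, fan_out)
print(W.shape)                                  # (3, 4, 4)
print(bool(jnp.allclose(W[0], 0.)), bool(jnp.allclose(W[2], 0.)))  # True True
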
def test_canonicalize_type(self):
  expected = {
      True: _EXPECTED_CANONICALIZE_X64,
      False: _EXPECTED_CANONICALIZE_X32,
  }
  for in_dtype, expected_dtype in expected[FLAGS.jax_enable_x64].items():
    self.assertEqual(dtypes.canonicalize_dtype(in_dtype), expected_dtype)

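# Illustration (not the actual test tables above): _EXPECTED_CANONICALIZE_X32 and
# _EXPECTED_CANONICALIZE_X64 encode the rule that canonicalize_dtype narrows 64-bit
# types to their 32-bit counterparts when x64 is disabled, and leaves them unchanged
# when it is enabled. A minimal sketch of that rule:
import numpy as np
import jax
from jax import dtypes

jax.config.update("jax_enable_x64", False)
assert dtypes.canonicalize_dtype(np.int64) == np.dtype("int32")
assert dtypes.canonicalize_dtype(np.float64) == np.dtype("float32")
assert dtypes.canonicalize_dtype(np.complex128) == np.dtype("complex64")
assert dtypes.canonicalize_dtype(np.float32) == np.dtype("float32")
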
def testMapCoordinates(self, shape, dtype, coords_shape, coords_dtype, order,
                       mode, cval, impl, round_, rng_factory):
  def args_maker():
    x = np.arange(prod(shape), dtype=dtype).reshape(shape)
    coords = [(size - 1) * rng(coords_shape, coords_dtype) for size in shape]
    if round_:
      coords = [c.round().astype(int) for c in coords]
    return x, coords

  rng = rng_factory(self.rng())
  lsp_op = lambda x, c: lsp_ndimage.map_coordinates(
      x, c, order=order, mode=mode, cval=cval)
  impl_fun = (osp_ndimage.map_coordinates if impl == "original"
              else _fixed_ref_map_coordinates)
  osp_op = lambda x, c: impl_fun(x, c, order=order, mode=mode, cval=cval)

  if dtype in float_dtypes:
    epsilon = max([dtypes.finfo(dtypes.canonicalize_dtype(d)).eps
                   for d in [dtype, coords_dtype]])
    self._CheckAgainstNumpy(osp_op, lsp_op, args_maker, tol=100 * epsilon)
  else:
    self._CheckAgainstNumpy(osp_op, lsp_op, args_maker, tol=0)

def __init__(self, v=0., log_density=0., event_dim=0, validate_args=None, value=None):
  if value is not None:
    v = value
    warnings.warn(
        "`value` argument has been deprecated in favor of `v` argument.",
        FutureWarning)

  if event_dim > jnp.ndim(v):
    raise ValueError(
        'Expected event_dim <= v.dim(), actual {} vs {}'.format(
            event_dim, jnp.ndim(v)))
  batch_dim = jnp.ndim(v) - event_dim
  batch_shape = jnp.shape(v)[:batch_dim]
  event_shape = jnp.shape(v)[batch_dim:]
  self.v = lax.convert_element_type(v, canonicalize_dtype(jnp.float64))
  # NB: following Pyro implementation, log_density should be broadcasted to batch_shape
  self.log_density = promote_shapes(log_density, shape=batch_shape)[0]
  super(Delta, self).__init__(batch_shape, event_shape, validate_args=validate_args)

def test_gradients_with_custom_jvp(self, with_function=True):
  """Check gradients, for a function with custom JVP."""
  @jax.custom_jvp
  def f(x):
    return x * x

  @f.defjvp
  def f_jvp(primals, tangents):
    # 3 * x * x_t
    x, = primals
    x_dot, = tangents
    primal_out = f(x)
    tangent_out = 3. * x * x_dot
    return primal_out, tangent_out

  self.assertAllClose(4. * 4., f(4.))
  self.assertAllClose(3. * 4., jax.grad(f)(4.))

  f_tf = jax2tf.convert(f, with_gradient=True)
  if with_function:
    f_tf = tf.function(f_tf, autograph=False)
  self.assertAllClose(4. * 4., f_tf(jnp.float_(4.)))

  x = tf.Variable(4., dtype=dtypes.canonicalize_dtype(jnp.float_))
  with tf.GradientTape() as tape:
    tape.watch(x)
    y = f_tf(x)

  self.assertAllClose(4. * 4., y)
  self.assertAllClose(3. * 4., tape.gradient(y, x))

def testBinaryPromotion(self, swap, jit):
  testcases = [
      (jnp.array(1.), 0., jnp.float_),
      (jnp.array(1.), jnp.array(0.), jnp.float_),
      (jnp.array(1.), jnp.array(0., dtype=jnp.float16), jnp.float_),
      (jnp.array(1.), jnp.array(0., dtype=jnp.float32), jnp.float_),
      (jnp.array(1.), jnp.array(0., dtype=jnp.float64), jnp.float64),
      (jnp.array(1., dtype=jnp.float16), 0., jnp.float16),
      (jnp.array(1., dtype=jnp.float32), 0., jnp.float32),
      (jnp.array(1., dtype=jnp.float64), 0., jnp.float64),
      (jnp.array(1., dtype=jnp.float16), jnp.array(0., dtype=jnp.float16), jnp.float16),
      (jnp.array(1., dtype=jnp.float16), jnp.array(0., dtype=jnp.float32), jnp.float32),
      (jnp.array(1., dtype=jnp.float16), jnp.array(0., dtype=jnp.float64), jnp.float64),
      (jnp.array(1., dtype=jnp.float32), jnp.array(0., dtype=jnp.float32), jnp.float32),
      (jnp.array(1., dtype=jnp.float32), jnp.array(0., dtype=jnp.float64), jnp.float64),
      (jnp.array(1., dtype=jnp.float64), jnp.array(0., dtype=jnp.float64), jnp.float64),
      (jnp.array([1.]), 0., jnp.float_),
      (jnp.array([1.]), jnp.array(0.), jnp.float_),
      (jnp.array([1.]), jnp.array(0., dtype=jnp.float16), jnp.float_),
      (jnp.array([1.]), jnp.array(0., dtype=jnp.float32), jnp.float_),
      (jnp.array([1.]), jnp.array(0., dtype=jnp.float64), jnp.float64),
      (jnp.array([1.], dtype=jnp.float32), jnp.array(0., dtype=jnp.float16), jnp.float32),
      (jnp.array([1.], dtype=jnp.float16), jnp.array(0., dtype=jnp.float32), jnp.float32),
      (jnp.array([1.], dtype=jnp.float16), 0., jnp.float16),
  ]

  op = jax.jit(operator.add) if jit else operator.add
  for x, y, dtype in testcases:
    x, y = (y, x) if swap else (x, y)
    # Use `op` so the jit-compiled variant is actually exercised when jit=True.
    z = op(x, y)
    self.assertTrue(isinstance(z, jnp.ndarray), msg=(x, y, z))
    self.assertEqual(z.dtype, dtypes.canonicalize_dtype(dtype), msg=(x, y, z))

def _tolerance(dtype: onp.dtype, tol: Optional[float] = None) -> float:
  tol = {} if tol is None else tol
  if not isinstance(tol, dict):
    return tol
  tol = {onp.dtype(key): value for key, value in tol.items()}
  dtype = _dtypes.canonicalize_dtype(onp.dtype(dtype))
  return tol.get(dtype, _default_tolerance()[dtype])

def one_hot(x, num_classes, *, dtype=jnp.float64):
  """One-hot encodes the given indices.

  Each index in the input ``x`` is encoded as a vector of zeros of length
  ``num_classes`` with the element at ``index`` set to one::

    >>> jax.nn.one_hot(jnp.array([0, 1, 2]), 3)
    DeviceArray([[1., 0., 0.],
                 [0., 1., 0.],
                 [0., 0., 1.]], dtype=float32)

  Indices outside the range [0, num_classes) will be encoded as zeros::

    >>> jax.nn.one_hot(jnp.array([-1, 3]), 3)
    DeviceArray([[0., 0., 0.],
                 [0., 0., 0.]], dtype=float32)

  Args:
    x: A tensor of indices.
    num_classes: Number of classes in the one-hot dimension.
    dtype: optional, a float dtype for the returned values (default float64
      if jax_enable_x64 is true, otherwise float32).
  """
  dtype = dtypes.canonicalize_dtype(dtype)
  x = jnp.asarray(x)
  lhs = x[..., jnp.newaxis]
  rhs = lax.broadcast_to_rank(jnp.arange(num_classes, dtype=x.dtype), lhs.ndim)
  return jnp.array(lhs == rhs, dtype=dtype)

def enumerate_support(self, expand=True):
  n = self.event_shape[-1]
  values = jnp.identity(n, dtype=canonicalize_dtype(self.dtype))
  values = values.reshape((n,) + (1,) * len(self.batch_shape) + (n,))
  if expand:
    values = jnp.broadcast_to(values, (n,) + self.batch_shape + (n,))
  return values

def fftfreq(n, d=1.0):
  dtype = dtypes.canonicalize_dtype(jnp.float_)
  if isinstance(n, (list, tuple)):
    raise ValueError(
        "The n argument of jax.numpy.fft.fftfreq only takes an int. "
        "Got n = %s." % list(n))
  elif isinstance(d, (list, tuple)):
    raise ValueError(
        "The d argument of jax.numpy.fft.fftfreq only takes a single value. "
        "Got d = %s." % list(d))

  k = jnp.zeros(n, dtype=dtype)
  if n % 2 == 0:
    # k[0: n // 2 - 1] = jnp.arange(0, n // 2 - 1)
    k = k.at[0:n // 2].set(jnp.arange(0, n // 2, dtype=dtype))
    # k[n // 2:] = jnp.arange(-n // 2, -1)
    k = k.at[n // 2:].set(jnp.arange(-n // 2, 0, dtype=dtype))
  else:
    # k[0: (n - 1) // 2] = jnp.arange(0, (n - 1) // 2)
    k = k.at[0:(n - 1) // 2 + 1].set(
        jnp.arange(0, (n - 1) // 2 + 1, dtype=dtype))
    # k[(n - 1) // 2 + 1:] = jnp.arange(-(n - 1) // 2, -1)
    k = k.at[(n - 1) // 2 + 1:].set(
        jnp.arange(-(n - 1) // 2, 0, dtype=dtype))

  return k / (d * n)

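# A minimal usage sketch for fftfreq above; the expected values follow the standard
# FFT frequency ordering, matching numpy.fft.fftfreq:
import numpy as np
import jax.numpy as jnp

np.testing.assert_allclose(jnp.fft.fftfreq(4), [0., 0.25, -0.5, -0.25])
np.testing.assert_allclose(jnp.fft.fftfreq(5, d=0.5), [0., 0.4, 0.8, -0.8, -0.4],
                           rtol=1e-6)
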
def testDefaultTypes(self, type, dtype):
  for f in [jnp.array, jax.jit(jnp.array), jax.jit(lambda x: x)]:
    y = f(type(0))
    self.assertTrue(isinstance(y, jnp.ndarray), msg=(f, y))
    self.assertEqual(y.dtype, dtypes.canonicalize_dtype(dtype), msg=(f, y))

def eig_abstract_eval(operand, *, compute_left_eigenvectors,
                      compute_right_eigenvectors):
  if isinstance(operand, ShapedArray):
    if operand.ndim < 2 or operand.shape[-2] != operand.shape[-1]:
      raise ValueError("Argument to nonsymmetric eigendecomposition must have "
                       "shape [..., n, n], got shape {}".format(operand.shape))

    batch_dims = operand.shape[:-2]
    n = operand.shape[-1]
    dtype = np.complex64 if dtypes.finfo(operand.dtype).bits == 32 else np.complex128
    dtype = dtypes.canonicalize_dtype(dtype)
    vl = vr = ShapedArray(batch_dims + (n, n), dtype)
    w = ShapedArray(batch_dims + (n,), dtype)
  else:
    raise NotImplementedError

  output = [w]
  if compute_left_eigenvectors:
    output.append(vl)
  if compute_right_eigenvectors:
    output.append(vr)
  return tuple(output)

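# Illustration of the dtype rule above: a real float32 input yields complex64
# eigenvalues and eigenvectors (nonsymmetric eig is CPU-only in the JAX versions
# these snippets come from).
import jax.numpy as jnp

w, v = jnp.linalg.eig(jnp.eye(2, dtype=jnp.float32))
print(w.dtype, v.dtype)   # complex64 complex64
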
def random_inputs(rng, input_shape):
  if type(input_shape) is tuple:
    return rng.randn(*input_shape).astype(dtypes.canonicalize_dtype(np.float_))
  elif type(input_shape) is list:
    return [random_inputs(rng, shape) for shape in input_shape]
  else:
    raise TypeError(type(input_shape))

def test_gradients_with_custom_vjp(self, with_function=True):
  """Check gradients, for a function with custom VJP."""
  @jax.custom_vjp
  def f(x):
    return x * x

  # f_fwd: a -> (b, residual)
  def f_fwd(x):
    return f(x), 3. * x

  # f_bwd: (residual, CT b) -> [CT a]
  def f_bwd(residual, ct_b):
    return (residual * ct_b,)

  f.defvjp(f_fwd, f_bwd)

  self.assertAllClose(4. * 4., f(4.))
  self.assertAllClose(3. * 4., jax.grad(f)(4.))

  f_tf = jax2tf.convert(f, with_gradient=True)
  if with_function:
    f_tf = tf.function(f_tf, autograph=False)
  self.assertAllClose(4. * 4., f_tf(jnp.float_(4.)))

  x = tf.Variable(4., dtype=dtypes.canonicalize_dtype(jnp.float_))
  with tf.GradientTape() as tape:
    tape.watch(x)
    y = f_tf(x)

  self.assertAllClose(4. * 4., y)
  self.assertAllClose(3. * 4., tape.gradient(y, x))

def canonicalize_res(res):
  res_dtype = builder.get_shape(res).numpy_dtype()
  jax_res_dtype = dtypes.canonicalize_dtype(res_dtype)
  if res_dtype != jax_res_dtype:
    new_etype = xla_client.dtype_to_etype(jax_res_dtype)
    return xops.ConvertElementType(res, new_element_type=new_etype)
  else:
    return res

def test_randint_bounds(self, dtype):
  min = np.iinfo(dtype).min
  max = np.iinfo(dtype).max
  key = random.PRNGKey(1701)
  shape = (10,)
  if np.iinfo(dtype).bits < np.iinfo(dtypes.canonicalize_dtype(int)).bits:
    expected = random.randint(key, shape, min, max, dtype)
    self.assertArraysEqual(
        expected, random.randint(key, shape, min - 12345, max + 12345, dtype))
  else:
    self.assertRaises(OverflowError, random.randint, key, shape,
                      min - 12345, max + 12345, dtype)

def __init__(self, value=0., log_density=0., event_dim=0, validate_args=None):
  if event_dim > jnp.ndim(value):
    raise ValueError('Expected event_dim <= v.dim(), actual {} vs {}'
                     .format(event_dim, jnp.ndim(value)))
  batch_dim = jnp.ndim(value) - event_dim
  batch_shape = jnp.shape(value)[:batch_dim]
  event_shape = jnp.shape(value)[batch_dim:]
  self.value = lax.convert_element_type(value, canonicalize_dtype(jnp.float64))
  # NB: following Pyro implementation, log_density should be broadcasted to batch_shape
  self.log_density = promote_shapes(log_density, shape=batch_shape)[0]
  super(Delta, self).__init__(batch_shape, event_shape, validate_args=validate_args)

def _promote_arg_dtypes(*args):
  """Promotes `args` to a common inexact type."""
  def _to_inexact_type(type):
    return type if jnp.issubdtype(type, jnp.inexact) else jnp.float_

  inexact_types = [_to_inexact_type(jnp._dtype(arg)) for arg in args]
  dtype = dtypes.canonicalize_dtype(jnp.result_type(*inexact_types))
  args = [lax.convert_element_type(arg, dtype) for arg in args]
  if len(args) == 1:
    return args[0]
  else:
    return args

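# A hedged sketch of the promotion rule the private helper above applies, expressed
# with public APIs only: non-inexact inputs are first mapped to the default float
# type, then a common dtype is computed and canonicalized.
import jax.numpy as jnp
from jax import dtypes

a = jnp.arange(3)                       # int32: not inexact
b = jnp.ones(3, dtype=jnp.complex64)    # already inexact
a_ty = a.dtype if jnp.issubdtype(a.dtype, jnp.inexact) else jnp.float_
common = dtypes.canonicalize_dtype(jnp.result_type(a_ty, b.dtype))
print(common)  # complex64 under the default x32 config, complex128 with x64 enabled
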
def test_device_put_on_numpy_scalars(self, device_put_function):
  device = jax.devices()[0]
  for dtype in _SCALAR_NUMPY_TYPES:
    value = dtype(0)
    output_buffer = device_put_function(value, device=device)

    self.assertFalse(output_buffer.aval.weak_type)
    self.assertEqual(output_buffer.aval, jax.core.ShapedArray((), dtype))
    self.assertEqual(output_buffer.dtype, dtypes.canonicalize_dtype(dtype))

def zeros(key, shape, dtype: DType = jnp.float_):
  """An initializer that returns a constant array full of zeros.

  The ``key`` argument is ignored.

  >>> import jax, jax.numpy as jnp
  >>> jax.nn.initializers.zeros(jax.random.PRNGKey(42), (2, 3), jnp.float32)
  DeviceArray([[0., 0., 0.],
               [0., 0., 0.]], dtype=float32)
  """
  return jnp.zeros(shape, dtypes.canonicalize_dtype(dtype))

def canonical_res_aval(res_shape: xla.XlaShape) -> core.ShapedArray:
  if not res_shape.is_static():
    msg = ("Compiled TensorFlow function has dynamic output shape " +
           f"{res_shape}. call_tf can be used " +
           "in a staged context (under jax.jit, lax.scan, etc.) only with " +
           "compileable functions with static output shapes. " +
           "See https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#limitations-of-call-tf for a discussion.")
    raise ValueError(msg)

  res_dtype = res_shape.numpy_dtype()
  jax_res_dtype = dtypes.canonicalize_dtype(res_dtype)
  return core.ShapedArray(res_shape.dimensions(), jax_res_dtype)

def _ravel_list(*leaves):
  leaves_metadata = tree_map(
      lambda l: pytree_metadata(jnp.ravel(l), jnp.shape(l), jnp.size(l),
                                canonicalize_dtype(lax.dtype(l))),
      leaves)
  leaves_idx = jnp.cumsum(jnp.array((0,) + tuple(d.size for d in leaves_metadata)))

  def unravel_list(arr):
    return [jnp.reshape(lax.dynamic_slice_in_dim(arr, leaves_idx[i], m.size),
                        m.shape).astype(m.dtype)
            for i, m in enumerate(leaves_metadata)]

  flat = jnp.concatenate([m.flat for m in leaves_metadata]) if leaves_metadata else jnp.array([])
  return flat, unravel_list

def _poisson(key, rate, shape, dtype):
  # Ref: https://en.wikipedia.org/wiki/Poisson_distribution#Generating_Poisson-distributed_random_variables
  shape = shape or np.shape(rate)
  rate = lax.convert_element_type(rate, canonicalize_dtype(np.float64))
  rate = np.broadcast_to(rate, shape)
  rng_keys = random.split(key, np.size(rate))
  if xla_bridge.get_backend().platform == 'cpu':
    k = lax.map(_poisson_one, (rng_keys, np.reshape(rate, -1)))
  else:
    k = vmap(_poisson_one)((rng_keys, np.reshape(rate, -1)))
  k = lax.convert_element_type(k, dtype)
  return np.reshape(k, shape)

def test_device_put_on_numpy_arrays(self, device_put_function):
  device = jax.devices()[0]
  for dtype in _SCALAR_NUMPY_TYPES:
    value = np.zeros((3, 4), dtype=dtype)
    output_buffer = device_put_function(value, device=device)

    self.assertFalse(output_buffer.aval.weak_type)
    self.assertEqual(output_buffer.aval, jax.core.ShapedArray((3, 4), dtype))
    self.assertEqual(output_buffer.dtype, dtypes.canonicalize_dtype(dtype))
    np.testing.assert_array_equal(output_buffer, np.zeros((3, 4), dtype=dtype))

def one_hot(x: Array, num_classes: int, *,
            dtype: Any = jnp.float64, axis: Union[int, AxisName] = -1) -> Array:
  """One-hot encodes the given indices.

  Each index in the input ``x`` is encoded as a vector of zeros of length
  ``num_classes`` with the element at ``index`` set to one::

    >>> jax.nn.one_hot(jnp.array([0, 1, 2]), 3)
    DeviceArray([[1., 0., 0.],
                 [0., 1., 0.],
                 [0., 0., 1.]], dtype=float32)

  Indices outside the range [0, num_classes) will be encoded as zeros::

    >>> jax.nn.one_hot(jnp.array([-1, 3]), 3)
    DeviceArray([[0., 0., 0.],
                 [0., 0., 0.]], dtype=float32)

  Args:
    x: A tensor of indices.
    num_classes: Number of classes in the one-hot dimension.
    dtype: optional, a float dtype for the returned values (default float64
      if jax_enable_x64 is true, otherwise float32).
    axis: the axis or axes along which the function should be computed.
  """
  num_classes = core.concrete_or_error(
      int, num_classes,
      "The error arose in jax.nn.one_hot argument `num_classes`.")
  dtype = dtypes.canonicalize_dtype(dtype)
  x = jnp.asarray(x)
  try:
    output_pos_axis = util.canonicalize_axis(axis, x.ndim + 1)
  except TypeError:
    axis_size = lax.psum(1, axis)
    if num_classes != axis_size:
      raise ValueError(f"Expected num_classes to match the size of axis {axis}, "
                       f"but {num_classes} != {axis_size}") from None
    axis_idx = lax.axis_index(axis)
    return jnp.asarray(x == axis_idx, dtype=dtype)
  axis = operator.index(axis)
  lhs = lax.expand_dims(x, (axis,))
  rhs_shape = [1] * x.ndim
  rhs_shape.insert(output_pos_axis, num_classes)
  rhs = lax.broadcast_in_dim(jnp.arange(num_classes, dtype=x.dtype),
                             rhs_shape, (output_pos_axis,))
  return jnp.asarray(lhs == rhs, dtype=dtype)

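# A minimal usage sketch of the `axis` argument above (assumes the version of
# jax.nn.one_hot defined here, which accepts `axis`): with axis=0 the class
# dimension is inserted in front of the index dimensions instead of appended.
import jax
import jax.numpy as jnp

print(jax.nn.one_hot(jnp.array([1, 2, 0]), 3, axis=0))
# [[0. 0. 1.]
#  [1. 0. 0.]
#  [0. 1. 0.]]   (each column is the encoding of the corresponding index)
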
def eig_abstract_eval(operand):
  if isinstance(operand, ShapedArray):
    if operand.ndim < 2 or operand.shape[-2] != operand.shape[-1]:
      raise ValueError("Argument to nonsymmetric eigendecomposition must have "
                       "shape [..., n, n], got shape {}".format(operand.shape))

    batch_dims = operand.shape[:-2]
    n = operand.shape[-1]
    dtype = onp.complex64 if dtypes.finfo(operand.dtype).bits == 32 else onp.complex128
    dtype = dtypes.canonicalize_dtype(dtype)
    vl = vr = ShapedArray(batch_dims + (n, n), dtype)
    w = ShapedArray(batch_dims + (n,), dtype)
  else:
    raise NotImplementedError
  return w, vl, vr