Example #1
    def test_device_put_on_python_scalars(self):
        device = jax.devices()[0]
        int_type = dtypes.canonicalize_dtype(np.int64)
        float_type = dtypes.canonicalize_dtype(np.float64)
        complex_type = dtypes.canonicalize_dtype(np.complex128)

        # int
        res = _cpp_device_put(1, device).to_py()
        self.assertEqual(res, 1)
        self.assertEqual(res.dtype, int_type)
        # We also compare to the Python JAX API, to make sure we have exactly the
        # same behavior. When JAX removes the flag and this feature, this test
        # will fail.
        self.assertEqual(jnp.asarray(1).dtype, res.dtype)

        # float
        res = _cpp_device_put(1.0, device).to_py()
        self.assertEqual(res, 1.0)
        self.assertEqual(res.dtype, float_type)
        self.assertEqual(jnp.asarray(1.0).dtype, res.dtype)

        # bool
        for bool_value in [True, False]:
            res = _cpp_device_put(bool_value, device).to_py()
            self.assertEqual(res, np.asarray(bool_value))
            self.assertEqual(res.dtype, np.bool_)
            self.assertEqual(jnp.asarray(bool_value).dtype, res.dtype)

        # Complex
        res = _cpp_device_put(1 + 1j, device).to_py()
        self.assertEqual(res, 1 + 1j)
        self.assertEqual(res.dtype, complex_type)
        self.assertEqual(jnp.asarray(1 + 1j).dtype, res.dtype)
Example #2
def test_ravel_pytree(pytree):
    flat, unravel_fn = ravel_pytree(pytree)
    unravel = unravel_fn(flat)
    tree_flatten(tree_multimap(lambda x, y: assert_allclose(x, y), unravel, pytree))
    assert all(tree_flatten(tree_multimap(lambda x, y:
                                          canonicalize_dtype(lax.dtype(x)) == canonicalize_dtype(lax.dtype(y)),
                                          unravel, pytree))[0])
Example #3
    def test_arg_signature_of_value(self):
        """Tests the C++ code-path."""
        jax_enable_x64 = config.x64_enabled

        # 1. Numpy scalar types
        for dtype in _SCALAR_NUMPY_TYPES:
            value = dtype(0)

            signature = jaxlib.jax_jit._ArgSignatureOfValue(
                value, jax_enable_x64)
            self.assertEqual(signature.dtype, jax.device_put(value).dtype)
            self.assertEqual(signature.shape, ())
            self.assertFalse(signature.weak_type)

        # 2. Numpy arrays
        for dtype in _SCALAR_NUMPY_TYPES:
            value = np.zeros((3, 4), dtype=dtype)

            signature = jaxlib.jax_jit._ArgSignatureOfValue(
                value, jax_enable_x64)
            self.assertEqual(signature.dtype, jax.device_put(value).dtype)
            self.assertEqual(signature.shape, (3, 4))
            self.assertFalse(signature.weak_type)

        int_type = dtypes.canonicalize_dtype(np.int64)
        float_type = dtypes.canonicalize_dtype(np.float64)
        complex_type = dtypes.canonicalize_dtype(np.complex128)

        # 3. Python scalar types
        # int
        signature = jaxlib.jax_jit._ArgSignatureOfValue(1, jax_enable_x64)
        self.assertEqual(signature.dtype, jax.device_put(1).dtype)
        self.assertEqual(signature.dtype, int_type)
        self.assertEqual(signature.shape, ())
        self.assertTrue(signature.weak_type)
        # float
        signature = jaxlib.jax_jit._ArgSignatureOfValue(1.0, jax_enable_x64)
        self.assertEqual(signature.dtype, jax.device_put(1.0).dtype)
        self.assertEqual(signature.dtype, float_type)
        self.assertEqual(signature.shape, ())
        self.assertTrue(signature.weak_type)
        # bool
        for bool_value in [True, False]:
            signature = jaxlib.jax_jit._ArgSignatureOfValue(
                bool_value, jax_enable_x64)
            self.assertEqual(signature.dtype, jax.device_put(bool_value).dtype)
            self.assertEqual(signature.dtype, np.bool_)
            self.assertEqual(signature.shape, ())
            self.assertTrue(signature.weak_type)
        # Complex
        signature = jaxlib.jax_jit._ArgSignatureOfValue(1 + 1j, jax_enable_x64)
        self.assertEqual(signature.dtype, jax.device_put(1 + 1j).dtype)
        self.assertEqual(signature.dtype, complex_type)
        self.assertEqual(signature.shape, ())
        self.assertTrue(signature.weak_type)
Example #4
  def assertDtypesMatch(self, x, y, *, canonicalize_dtypes=True):
    """Compares dtypes across JAX and TF dtypes. Overrides super method."""
    def to_numpy_dtype(dt):
      return dt if isinstance(dt, np.dtype) else dt.as_numpy_dtype

    if not config.FLAGS.jax_enable_x64 and canonicalize_dtypes:
      self.assertEqual(dtypes.canonicalize_dtype(to_numpy_dtype(jtu._dtype(x))),
                       dtypes.canonicalize_dtype(to_numpy_dtype(jtu._dtype(y))))
    else:
      self.assertEqual(to_numpy_dtype(jtu._dtype(x)),
                       to_numpy_dtype(jtu._dtype(y)))
Example #5
 def init(key, shape, dtype=dtype):
     dtype = dtypes.canonicalize_dtype(dtype)
     if len(shape) not in [3, 4, 5]:
         raise ValueError(
             "Delta orthogonal initializer requires a 3D, 4D or 5D "
             "shape.")
     if shape[-1] < shape[-2]:
         raise ValueError("`fan_in` must be less or equal than `fan_out`. ")
     ortho_init = orthogonal(scale=scale,
                             column_axis=column_axis,
                             dtype=dtype)
     ortho_matrix = ortho_init(key, shape[-2:])
     W = jnp.zeros(shape, dtype=dtype)
     if len(shape) == 3:
         k = shape[0]
         return ops.index_update(W, ops.index[(k - 1) // 2, ...],
                                 ortho_matrix)
     elif len(shape) == 4:
         k1, k2 = shape[:2]
         return ops.index_update(
             W, ops.index[(k1 - 1) // 2, (k2 - 1) // 2, ...], ortho_matrix)
     else:
         k1, k2, k3 = shape[:3]
         return ops.index_update(
             W, ops.index[(k1 - 1) // 2, (k2 - 1) // 2, (k3 - 1) // 2, ...],
             ortho_matrix)
Example #6
 def test_canonicalize_type(self):
   expected = {
       True: _EXPECTED_CANONICALIZE_X64,
       False: _EXPECTED_CANONICALIZE_X32,
   }
   for in_dtype, expected_dtype in expected[FLAGS.jax_enable_x64].items():
     self.assertEqual(dtypes.canonicalize_dtype(in_dtype), expected_dtype)
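For reference, a minimal sketch of the mapping that the expected dictionaries above encode, assuming the default configuration (jax_enable_x64 disabled): 64-bit types are narrowed to their 32-bit counterparts, while 32-bit and smaller types pass through unchanged.

import numpy as np
from jax import dtypes

print(dtypes.canonicalize_dtype(np.float64))  # float32 (float64 with jax_enable_x64)
print(dtypes.canonicalize_dtype(np.int64))    # int32 (int64 with jax_enable_x64)
print(dtypes.canonicalize_dtype(np.float32))  # float32 in either mode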
Example #7
    def testMapCoordinates(self, shape, dtype, coords_shape, coords_dtype,
                           order, mode, cval, impl, round_, rng_factory):
        def args_maker():
            x = np.arange(prod(shape), dtype=dtype).reshape(shape)
            coords = [(size - 1) * rng(coords_shape, coords_dtype)
                      for size in shape]
            if round_:
                coords = [c.round().astype(int) for c in coords]
            return x, coords

        rng = rng_factory(self.rng())
        lsp_op = lambda x, c: lsp_ndimage.map_coordinates(
            x, c, order=order, mode=mode, cval=cval)
        impl_fun = (osp_ndimage.map_coordinates
                    if impl == "original" else _fixed_ref_map_coordinates)
        osp_op = lambda x, c: impl_fun(x, c, order=order, mode=mode, cval=cval)
        if dtype in float_dtypes:
            epsilon = max([
                dtypes.finfo(dtypes.canonicalize_dtype(d)).eps
                for d in [dtype, coords_dtype]
            ])
            self._CheckAgainstNumpy(osp_op,
                                    lsp_op,
                                    args_maker,
                                    tol=100 * epsilon)
        else:
            self._CheckAgainstNumpy(osp_op, lsp_op, args_maker, tol=0)
Example #8
    def __init__(self,
                 v=0.,
                 log_density=0.,
                 event_dim=0,
                 validate_args=None,
                 value=None):
        if value is not None:
            v = value
            warnings.warn(
                "`value` argument has been deprecated in favor of `v` argument.",
                FutureWarning)

        if event_dim > jnp.ndim(v):
            raise ValueError(
                'Expected event_dim <= v.dim(), actual {} vs {}'.format(
                    event_dim, jnp.ndim(v)))
        batch_dim = jnp.ndim(v) - event_dim
        batch_shape = jnp.shape(v)[:batch_dim]
        event_shape = jnp.shape(v)[batch_dim:]
        self.v = lax.convert_element_type(v, canonicalize_dtype(jnp.float64))
        # NB: following Pyro implementation, log_density should be broadcasted to batch_shape
        self.log_density = promote_shapes(log_density, shape=batch_shape)[0]
        super(Delta, self).__init__(batch_shape,
                                    event_shape,
                                    validate_args=validate_args)
Example #9
  def test_gradients_with_custom_jvp(self, with_function=True):
    """Check gradients, for a function with custom JVP."""
    @jax.custom_jvp
    def f(x):
      return x * x

    @f.defjvp
    def f_jvp(primals, tangents):
      # 3 * x * x_t
      x, = primals
      x_dot, = tangents
      primal_out = f(x)
      tangent_out = 3. * x * x_dot
      return primal_out, tangent_out

    self.assertAllClose(4. * 4., f(4.))
    self.assertAllClose(3. * 4., jax.grad(f)(4.))

    f_tf = jax2tf.convert(f, with_gradient=True)
    if with_function:
      f_tf = tf.function(f_tf, autograph=False)
    self.assertAllClose(4. * 4., f_tf(jnp.float_(4.)))
    x = tf.Variable(4., dtype=dtypes.canonicalize_dtype(jnp.float_))
    with tf.GradientTape() as tape:
      tape.watch(x)
      y = f_tf(x)

    self.assertAllClose(4. * 4., y)
    self.assertAllClose(3. * 4., tape.gradient(y, x))
Example #10
 def testBinaryPromotion(self, swap, jit):
   testcases = [
     (jnp.array(1.), 0., jnp.float_),
     (jnp.array(1.), jnp.array(0.), jnp.float_),
     (jnp.array(1.), jnp.array(0., dtype=jnp.float16), jnp.float_),
     (jnp.array(1.), jnp.array(0., dtype=jnp.float32), jnp.float_),
     (jnp.array(1.), jnp.array(0., dtype=jnp.float64), jnp.float64),
     (jnp.array(1., dtype=jnp.float16), 0., jnp.float16),
     (jnp.array(1., dtype=jnp.float32), 0., jnp.float32),
     (jnp.array(1., dtype=jnp.float64), 0., jnp.float64),
     (jnp.array(1., dtype=jnp.float16), jnp.array(0., dtype=jnp.float16), jnp.float16),
     (jnp.array(1., dtype=jnp.float16), jnp.array(0., dtype=jnp.float32), jnp.float32),
     (jnp.array(1., dtype=jnp.float16), jnp.array(0., dtype=jnp.float64), jnp.float64),
     (jnp.array(1., dtype=jnp.float32), jnp.array(0., dtype=jnp.float32), jnp.float32),
     (jnp.array(1., dtype=jnp.float32), jnp.array(0., dtype=jnp.float64), jnp.float64),
     (jnp.array(1., dtype=jnp.float64), jnp.array(0., dtype=jnp.float64), jnp.float64),
     (jnp.array([1.]), 0., jnp.float_),
     (jnp.array([1.]), jnp.array(0.), jnp.float_),
     (jnp.array([1.]), jnp.array(0., dtype=jnp.float16), jnp.float_),
     (jnp.array([1.]), jnp.array(0., dtype=jnp.float32), jnp.float_),
     (jnp.array([1.]), jnp.array(0., dtype=jnp.float64), jnp.float64),
     (jnp.array([1.], dtype=jnp.float32), jnp.array(0., dtype=jnp.float16), jnp.float32),
     (jnp.array([1.], dtype=jnp.float16), jnp.array(0., dtype=jnp.float32), jnp.float32),
     (jnp.array([1.], dtype=jnp.float16), 0., jnp.float16),
   ]
   op = jax.jit(operator.add) if jit else operator.add
   for x, y, dtype in testcases:
     x, y = (y, x) if swap else (x, y)
     z = x + y
     self.assertTrue(isinstance(z, jnp.ndarray), msg=(x, y, z))
     self.assertEqual(z.dtype, dtypes.canonicalize_dtype(dtype), msg=(x, y, z))
Example #11
def _tolerance(dtype: onp.dtype, tol: Optional[float] = None) -> float:
  tol = {} if tol is None else tol
  if not isinstance(tol, dict):
    return tol
  tol = {onp.dtype(key): value for key, value in tol.items()}
  dtype = _dtypes.canonicalize_dtype(onp.dtype(dtype))
  return tol.get(dtype, _default_tolerance()[dtype])
Example #12
def one_hot(x, num_classes, *, dtype=jnp.float64):
  """One-hot encodes the given indicies.

  Each index in the input ``x`` is encoded as a vector of zeros of length
  ``num_classes`` with the element at ``index`` set to one::

  >>> jax.nn.one_hot(jnp.array([0, 1, 2]), 3)
  DeviceArray([[1., 0., 0.],
               [0., 1., 0.],
               [0., 0., 1.]], dtype=float32)

  Indices outside the range [0, num_classes) will be encoded as zeros::

  >>> jax.nn.one_hot(jnp.array([-1, 3]), 3)
  DeviceArray([[0., 0., 0.],
               [0., 0., 0.]], dtype=float32)

  Args:
    x: A tensor of indices.
    num_classes: Number of classes in the one-hot dimension.
    dtype: optional, a float dtype for the returned values (default float64 if
      jax_enable_x64 is true, otherwise float32).
  """
  dtype = dtypes.canonicalize_dtype(dtype)
  x = jnp.asarray(x)
  lhs = x[..., jnp.newaxis]
  rhs = lax.broadcast_to_rank(jnp.arange(num_classes, dtype=x.dtype), lhs.ndim)
  return jnp.array(lhs == rhs, dtype=dtype)
Example #13
 def enumerate_support(self, expand=True):
     n = self.event_shape[-1]
     values = jnp.identity(n, dtype=canonicalize_dtype(self.dtype))
     values = values.reshape((n,) + (1,) * len(self.batch_shape) + (n,))
     if expand:
         values = jnp.broadcast_to(values, (n,) + self.batch_shape + (n,))
     return values
Example #14
def fftfreq(n, d=1.0):
    dtype = dtypes.canonicalize_dtype(jnp.float_)
    if isinstance(n, (list, tuple)):
        raise ValueError(
            "The n argument of jax.numpy.fft.fftfreq only takes an int. "
            "Got n = %s." % list(n))

    elif isinstance(d, (list, tuple)):
        raise ValueError(
            "The d argument of jax.numpy.fft.fftfreq only takes a single value. "
            "Got d = %s." % list(d))

    k = jnp.zeros(n, dtype=dtype)
    if n % 2 == 0:
        # k[0: n // 2] = jnp.arange(0, n // 2)
        k = k.at[0:n // 2].set(jnp.arange(0, n // 2, dtype=dtype))

        # k[n // 2:] = jnp.arange(-n // 2, 0)
        k = k.at[n // 2:].set(jnp.arange(-n // 2, 0, dtype=dtype))

    else:
        # k[0: (n - 1) // 2 + 1] = jnp.arange(0, (n - 1) // 2 + 1)
        k = k.at[0:(n - 1) // 2 + 1].set(
            jnp.arange(0, (n - 1) // 2 + 1, dtype=dtype))

        # k[(n - 1) // 2 + 1:] = jnp.arange(-(n - 1) // 2, 0)
        k = k.at[(n - 1) // 2 + 1:].set(
            jnp.arange(-(n - 1) // 2, 0, dtype=dtype))

    return k / (d * n)
Example #15
 def testDefaultTypes(self, type, dtype):
     for f in [jnp.array, jax.jit(jnp.array), jax.jit(lambda x: x)]:
         y = f(type(0))
         self.assertTrue(isinstance(y, jnp.ndarray), msg=(f, y))
         self.assertEqual(y.dtype,
                          dtypes.canonicalize_dtype(dtype),
                          msg=(f, y))
Example #16
def eig_abstract_eval(operand, *, compute_left_eigenvectors,
                      compute_right_eigenvectors):
    if isinstance(operand, ShapedArray):
        if operand.ndim < 2 or operand.shape[-2] != operand.shape[-1]:
            raise ValueError(
                "Argument to nonsymmetric eigendecomposition must have "
                "shape [..., n, n], got shape {}".format(operand.shape))

        batch_dims = operand.shape[:-2]
        n = operand.shape[-1]
        dtype = np.complex64 if dtypes.finfo(
            operand.dtype).bits == 32 else np.complex128
        dtype = dtypes.canonicalize_dtype(dtype)
        vl = vr = ShapedArray(batch_dims + (n, n), dtype)
        w = ShapedArray(batch_dims + (n, ), dtype)
    else:
        raise NotImplementedError

    output = [w]
    if compute_left_eigenvectors:
        output.append(vl)
    if compute_right_eigenvectors:
        output.append(vr)

    return tuple(output)
Example #17
def random_inputs(rng, input_shape):
  if type(input_shape) is tuple:
    return rng.randn(*input_shape).astype(dtypes.canonicalize_dtype(np.float_))
  elif type(input_shape) is list:
    return [random_inputs(rng, shape) for shape in input_shape]
  else:
    raise TypeError(type(input_shape))
Example #18
  def test_gradients_with_custom_vjp(self, with_function=True):
    """Check gradients, for a function with custom VJP."""
    @jax.custom_vjp
    def f(x):
      return x * x

    # f_fwd: a -> (b, residual)
    def f_fwd(x):
      return f(x), 3. * x
    # f_bwd: (residual, CT b) -> [CT a]
    def f_bwd(residual, ct_b):
      return residual * ct_b,

    f.defvjp(f_fwd, f_bwd)

    self.assertAllClose(4. * 4., f(4.))
    self.assertAllClose(3. * 4., jax.grad(f)(4.))

    f_tf = jax2tf.convert(f, with_gradient=True)
    if with_function:
      f_tf = tf.function(f_tf, autograph=False)
    self.assertAllClose(4. * 4., f_tf(jnp.float_(4.)))
    x = tf.Variable(4., dtype=dtypes.canonicalize_dtype(jnp.float_))
    with tf.GradientTape() as tape:
      tape.watch(x)
      y = f_tf(x)

    self.assertAllClose(4. * 4., y)
    self.assertAllClose(3. * 4., tape.gradient(y, x))
Example #19
 def canonicalize_res(res):
     res_dtype = builder.get_shape(res).numpy_dtype()
     jax_res_dtype = dtypes.canonicalize_dtype(res_dtype)
     if res_dtype != jax_res_dtype:
         new_etype = xla_client.dtype_to_etype(jax_res_dtype)
         return xops.ConvertElementType(res, new_element_type=new_etype)
     else:
         return res
Example #20
 def test_randint_bounds(self, dtype):
   min = np.iinfo(dtype).min
   max = np.iinfo(dtype).max
   key = random.PRNGKey(1701)
   shape = (10,)
   if np.iinfo(dtype).bits < np.iinfo(dtypes.canonicalize_dtype(int)).bits:
     expected = random.randint(key, shape, min, max, dtype)
     self.assertArraysEqual(expected, random.randint(key, shape, min - 12345, max + 12345, dtype))
   else:
     self.assertRaises(OverflowError, random.randint, key, shape, min - 12345, max + 12345, dtype)
Example #21
 def __init__(self, value=0., log_density=0., event_dim=0, validate_args=None):
     if event_dim > jnp.ndim(value):
         raise ValueError('Expected event_dim <= v.dim(), actual {} vs {}'
                          .format(event_dim, jnp.ndim(value)))
     batch_dim = jnp.ndim(value) - event_dim
     batch_shape = jnp.shape(value)[:batch_dim]
     event_shape = jnp.shape(value)[batch_dim:]
     self.value = lax.convert_element_type(value, canonicalize_dtype(jnp.float64))
     # NB: following Pyro implementation, log_density should be broadcasted to batch_shape
     self.log_density = promote_shapes(log_density, shape=batch_shape)[0]
     super(Delta, self).__init__(batch_shape, event_shape, validate_args=validate_args)
Example #22
def _promote_arg_dtypes(*args):
  """Promotes `args` to a common inexact type."""
  def _to_inexact_type(type):
    return type if jnp.issubdtype(type, jnp.inexact) else jnp.float_
  inexact_types = [_to_inexact_type(jnp._dtype(arg)) for arg in args]
  dtype = dtypes.canonicalize_dtype(jnp.result_type(*inexact_types))
  args = [lax.convert_element_type(arg, dtype) for arg in args]
  if len(args) == 1:
    return args[0]
  else:
    return args
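A minimal sketch of the promotion rule above, using hypothetical inputs: non-inexact arguments are treated as the default float type before the common result type is canonicalized.

import jax.numpy as jnp
from jax import dtypes

args = [jnp.arange(3), jnp.ones(3, jnp.float16)]      # integer and float16 inputs
inexact = [a.dtype if jnp.issubdtype(a.dtype, jnp.inexact) else jnp.float_
           for a in args]
common = dtypes.canonicalize_dtype(jnp.result_type(*inexact))
print(common)  # float32 when jax_enable_x64 is off, float64 when it is on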
Example #23
  def test_device_put_on_numpy_scalars(self, device_put_function):

    device = jax.devices()[0]
    for dtype in _SCALAR_NUMPY_TYPES:
      value = dtype(0)

      output_buffer = device_put_function(value, device=device)

      self.assertFalse(output_buffer.aval.weak_type)
      self.assertEqual(output_buffer.aval, jax.core.ShapedArray((), dtype))
      self.assertEqual(output_buffer.dtype, dtypes.canonicalize_dtype(dtype))
Example #24
def zeros(key, shape, dtype: DType = jnp.float_):
    """An initializer that returns a constant array full of zeros.

  The ``key`` argument is ignored.

  >>> import jax, jax.numpy as jnp
  >>> jax.nn.initializers.zeros(jax.random.PRNGKey(42), (2, 3), jnp.float32)
  DeviceArray([[0., 0., 0.],
               [0., 0., 0.]], dtype=float32)
  """
    return jnp.zeros(shape, dtypes.canonicalize_dtype(dtype))
Example #25
  def canonical_res_aval(res_shape: xla.XlaShape) -> core.ShapedArray:
    if not res_shape.is_static():
      msg = ("Compiled TensorFlow function has dynamic output shape " +
             f"{res_shape}. call_tf can used " +
             "in a staged context (under jax.jit, lax.scan, etc.) only with " +
             "compileable functions with static output shapes. " +
             "See https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#limitations-of-call-tf for a discussion.")
      raise ValueError(msg)

    res_dtype = res_shape.numpy_dtype()
    jax_res_dtype = dtypes.canonicalize_dtype(res_dtype)
    return core.ShapedArray(res_shape.dimensions(), jax_res_dtype)
Example #26
def _ravel_list(*leaves):
    leaves_metadata = tree_map(lambda l: pytree_metadata(
        jnp.ravel(l), jnp.shape(l), jnp.size(l), canonicalize_dtype(lax.dtype(l))), leaves)
    leaves_idx = jnp.cumsum(jnp.array((0,) + tuple(d.size for d in leaves_metadata)))

    def unravel_list(arr):
        return [jnp.reshape(lax.dynamic_slice_in_dim(arr, leaves_idx[i], m.size),
                            m.shape).astype(m.dtype)
                for i, m in enumerate(leaves_metadata)]

    flat = jnp.concatenate([m.flat for m in leaves_metadata]) if leaves_metadata else jnp.array([])
    return flat, unravel_list
Example #27
def _poisson(key, rate, shape, dtype):
    # Ref: https://en.wikipedia.org/wiki/Poisson_distribution#Generating_Poisson-distributed_random_variables
    shape = shape or np.shape(rate)
    rate = lax.convert_element_type(rate, canonicalize_dtype(np.float64))
    rate = np.broadcast_to(rate, shape)
    rng_keys = random.split(key, np.size(rate))
    if xla_bridge.get_backend().platform == 'cpu':
        k = lax.map(_poisson_one, (rng_keys, np.reshape(rate, -1)))
    else:
        k = vmap(_poisson_one)((rng_keys, np.reshape(rate, -1)))
    k = lax.convert_element_type(k, dtype)
    return np.reshape(k, shape)
Example #28
  def test_device_put_on_numpy_arrays(self, device_put_function):

    device = jax.devices()[0]
    for dtype in _SCALAR_NUMPY_TYPES:
      value = np.zeros((3, 4), dtype=dtype)
      output_buffer = device_put_function(value, device=device)

      self.assertFalse(output_buffer.aval.weak_type)
      self.assertEqual(output_buffer.aval, jax.core.ShapedArray((3, 4), dtype))
      self.assertEqual(output_buffer.dtype, dtypes.canonicalize_dtype(dtype))
      np.testing.assert_array_equal(output_buffer, np.zeros((3, 4),
                                                            dtype=dtype))
Example #29
def one_hot(x: Array,
            num_classes: int,
            *,
            dtype: Any = jnp.float64,
            axis: Union[int, AxisName] = -1) -> Array:
    """One-hot encodes the given indicies.

  Each index in the input ``x`` is encoded as a vector of zeros of length
  ``num_classes`` with the element at ``index`` set to one::

    >>> jax.nn.one_hot(jnp.array([0, 1, 2]), 3)
    DeviceArray([[1., 0., 0.],
                 [0., 1., 0.],
                 [0., 0., 1.]], dtype=float32)

  Indices outside the range [0, num_classes) will be encoded as zeros::

    >>> jax.nn.one_hot(jnp.array([-1, 3]), 3)
    DeviceArray([[0., 0., 0.],
                 [0., 0., 0.]], dtype=float32)

  Args:
    x: A tensor of indices.
    num_classes: Number of classes in the one-hot dimension.
    dtype: optional, a float dtype for the returned values (default float64 if
      jax_enable_x64 is true, otherwise float32).
    axis: the axis or axes along which the function should be
      computed.
  """
    num_classes = core.concrete_or_error(
        int, num_classes,
        "The error arose in jax.nn.one_hot argument `num_classes`.")
    dtype = dtypes.canonicalize_dtype(dtype)
    x = jnp.asarray(x)
    try:
        output_pos_axis = util.canonicalize_axis(axis, x.ndim + 1)
    except TypeError:
        axis_size = lax.psum(1, axis)
        if num_classes != axis_size:
            raise ValueError(
                f"Expected num_classes to match the size of axis {axis}, "
                f"but {num_classes} != {axis_size}") from None
        axis_idx = lax.axis_index(axis)
        return jnp.asarray(x == axis_idx, dtype=dtype)
    axis = operator.index(axis)
    lhs = lax.expand_dims(x, (axis, ))
    rhs_shape = [1] * x.ndim
    rhs_shape.insert(output_pos_axis, num_classes)
    rhs = lax.broadcast_in_dim(jnp.arange(num_classes, dtype=x.dtype),
                               rhs_shape, (output_pos_axis, ))
    return jnp.asarray(lhs == rhs, dtype=dtype)
Example #30
def eig_abstract_eval(operand):
  if isinstance(operand, ShapedArray):
    if operand.ndim < 2 or operand.shape[-2] != operand.shape[-1]:
      raise ValueError("Argument to nonsymmetric eigendecomposition must have "
                       "shape [..., n, n], got shape {}".format(operand.shape))

    batch_dims = operand.shape[:-2]
    n = operand.shape[-1]
    dtype = onp.complex64 if dtypes.finfo(operand.dtype).bits == 32 else onp.complex128
    dtype = dtypes.canonicalize_dtype(dtype)
    vl = vr = ShapedArray(batch_dims + (n, n), dtype)
    w = ShapedArray(batch_dims + (n,), dtype)
  else:
    raise NotImplementedError
  return w, vl, vr