Пример #1
0
def one_hot(x, num_classes, *, dtype=jnp.float64):
  """Return the one-hot encoding of the given indices.

  A new trailing axis of length ``num_classes`` is appended to ``x``. Along
  that axis the position equal to the index holds one and every other
  position holds zero::

  >>> jax.nn.one_hot(jnp.array([0, 1, 2]), 3)
  DeviceArray([[1., 0., 0.],
               [0., 1., 0.],
               [0., 0., 1.]], dtype=float32)

  An index outside the range [0, num_classes) matches no position, so it is
  encoded as all zeros::

  >>> jax.nn.one_hot(jnp.array([-1, 3]), 3)
  DeviceArray([[0., 0., 0.],
               [0., 0., 0.]], dtype=float32)

  Args:
    x: Array-like of indices; converted with ``jnp.asarray``.
    num_classes: Length of the appended one-hot axis.
    dtype: optional, dtype of the returned values. It is canonicalized before
      use (default float64 if jax_enable_x64 is true, otherwise float32).

  Returns:
    An array of shape ``x.shape + (num_classes,)`` holding ones and zeros in
    the canonicalized ``dtype``.
  """
  out_dtype = dtypes.canonicalize_dtype(dtype)
  indices = jnp.asarray(x)
  # Trailing unit axis so every index broadcasts against every class id.
  expanded = jnp.expand_dims(indices, -1)
  class_ids = jnp.arange(num_classes, dtype=indices.dtype)
  # Left-pad the class ids with unit axes up to the rank of ``expanded``.
  class_ids = lax.broadcast_to_rank(class_ids, expanded.ndim)
  matches = expanded == class_ids
  return jnp.array(matches, dtype=out_dtype)
Пример #2
0
 def init(key, shape, dtype=dtype):
     """Sample an array of the given shape from a QR-orthogonalized Gaussian.

     The array is viewed as a 2D matrix whose columns run along
     ``column_axis`` and whose rows collect all remaining axes. ``scale``,
     ``column_axis`` and the default ``dtype`` are taken from the enclosing
     scope, which is not visible in this block.

     Args:
       key: random key passed to ``random.normal``.
       shape: target shape; must have at least two dimensions.
       dtype: dtype of the sampled Gaussian matrix.

     Returns:
       ``scale`` times the orthogonalized matrix, reshaped to ``shape``.

     Raises:
       ValueError: if ``shape`` has fewer than two dimensions.
     """
     if len(shape) < 2:
         raise ValueError(
             "orthogonal initializer requires at least a 2D shape")

     n_cols = shape[column_axis]
     n_rows = prod(shape) // n_cols
     is_wide = n_rows < n_cols

     # Always factorize a matrix with at least as many rows as columns; the
     # wide case is handled by factorizing the transpose and flipping back.
     if is_wide:
         matrix_shape = (n_cols, n_rows)
     else:
         matrix_shape = (n_rows, n_cols)

     gaussian = random.normal(key, matrix_shape, dtype)
     q, r = jnp.linalg.qr(gaussian)

     # Multiply each column of q by the sign of the matching diagonal entry
     # of r -- needed for a uniform distribution.
     signs = lax.broadcast_to_rank(jnp.sign(jnp.diag(r)), rank=q.ndim)
     q = q * signs

     if is_wide:
         q = q.T

     # Restore the non-column axes first with the column axis last, then move
     # the column axis to its requested position.
     leading_dims = tuple(np.delete(shape, column_axis))
     q = jnp.reshape(q, leading_dims + (n_cols, ))
     q = jnp.moveaxis(q, -1, column_axis)
     return scale * q