def one_hot(x, num_classes, *, dtype=jnp.float64):
    """Return a one-hot encoding of the indices in ``x``.

    A trailing axis of length ``num_classes`` is appended to ``x``. Along that
    axis, the position equal to the index holds one and every other position
    holds zero::

      >>> jax.nn.one_hot(jnp.array([0, 1, 2]), 3)
      DeviceArray([[1., 0., 0.],
                   [0., 1., 0.],
                   [0., 0., 1.]], dtype=float32)

    Indices outside the range [0, num_classes) match no class and are
    therefore encoded as all zeros::

      >>> jax.nn.one_hot(jnp.array([-1, 3]), 3)
      DeviceArray([[0., 0., 0.],
                   [0., 0., 0.]], dtype=float32)

    Args:
      x: A tensor of indices; it is converted with ``jnp.asarray``.
      num_classes: Number of classes in the one-hot dimension.
      dtype: optional, a float dtype for the returned values (default float64
        if jax_enable_x64 is true, otherwise float32).

    Returns:
      An array of shape ``x.shape + (num_classes,)`` whose dtype is the
      canonicalized ``dtype``.
    """
    out_dtype = dtypes.canonicalize_dtype(dtype)
    indices = jnp.asarray(x)
    # A trailing singleton axis lets every index be compared against each
    # class id through broadcasting.
    expanded = indices[..., jnp.newaxis]
    class_ids = jnp.arange(num_classes, dtype=indices.dtype)
    class_ids = lax.broadcast_to_rank(class_ids, expanded.ndim)
    matches = expanded == class_ids
    return jnp.array(matches, dtype=out_dtype)
def init(key, shape, dtype=dtype):
    """Sample an array of the given shape built from an orthogonal matrix.

    ``column_axis``, ``scale`` and the default ``dtype`` are taken from the
    enclosing scope, which is not part of this block.

    Args:
      key: Random key forwarded to ``random.normal``.
      shape: Target shape; must have at least two dimensions.
      dtype: Dtype of the normal samples drawn before the QR step.

    Returns:
      ``scale`` times the Q factor, reshaped so that the dimension at
      ``column_axis`` has size ``shape[column_axis]`` and the remaining
      dimensions follow ``shape``.

    Raises:
      ValueError: If ``shape`` has fewer than two dimensions.
    """
    if len(shape) < 2:
        raise ValueError(
            "orthogonal initializer requires at least a 2D shape")

    # Flatten every axis except ``column_axis`` into a single row dimension.
    col_dim = shape[column_axis]
    row_dim = prod(shape) // col_dim
    is_wide = row_dim < col_dim

    # The QR step always runs on the orientation with rows >= columns; a wide
    # target is sampled transposed and flipped back afterwards.
    sample_shape = (col_dim, row_dim) if is_wide else (row_dim, col_dim)
    gaussian = random.normal(key, sample_shape, dtype)
    q_factor, r_factor = jnp.linalg.qr(gaussian)

    # Multiply each column of Q by the sign of the matching diagonal entry of
    # R; this is needed for a uniform distribution.
    signs = lax.broadcast_to_rank(
        jnp.sign(jnp.diag(r_factor)), rank=q_factor.ndim)
    q_factor = q_factor * signs
    if is_wide:
        q_factor = q_factor.T

    # Restore the non-column dimensions, then move the column dimension from
    # the last position back to ``column_axis``.
    other_dims = tuple(np.delete(shape, column_axis))
    q_factor = jnp.reshape(q_factor, other_dims + (col_dim,))
    q_factor = jnp.moveaxis(q_factor, -1, column_axis)
    return scale * q_factor