Example #1
 def __call__(self, shape: Shape, dtype: DType) -> np.ndarray:
     """Sample a truncated normal, shifted by ``self.mean`` and scaled by ``self.stddev``.

     A standard normal truncated to ``[-2, 2]`` is drawn with a fresh key from
     ``hooks.next_rng_key()``. The result is ``stddev * sample + mean``, where
     ``mean`` and ``stddev`` are first cast to ``dtype``.

     Args:
         shape: Shape of the returned array.
         dtype: Dtype of the samples; ``self.mean`` and ``self.stddev`` are
             cast to it as well.

     Returns:
         An array of the requested ``shape`` and ``dtype``. For a non-negative
         ``stddev``, every value lies in ``[mean - 2*stddev, mean + 2*stddev]``.

     NOTE: the truncation happens on the standard normal *before* scaling, so
     the standard deviation of the returned values is roughly ``0.88 * stddev``,
     not ``stddev``.
     """
     m = jax.lax.convert_element_type(self.mean, dtype)
     s = jax.lax.convert_element_type(self.stddev, dtype)
     # Bounds are in units of the underlying standard normal (i.e. +/- 2 sigma).
     unscaled = jax.random.truncated_normal(
         hooks.next_rng_key(), -2.0, 2.0, shape, dtype
     )
     return s * unscaled + m
Example #2
 def __call__(self, shape: Shape, dtype: DType) -> np.ndarray:
     """Sample an (semi-)orthogonal tensor of the given ``shape``, times ``self.scale``.

     The dimension at ``self.axis`` is treated as the row count. All other
     dimensions are flattened into the column count. A Gaussian matrix is then
     QR-decomposed, and the sign-corrected Q factor is reshaped back to
     ``shape``.

     Args:
         shape: Target shape. It must have at least two dimensions.
         dtype: Dtype for the Gaussian draw and for the cast ``self.scale``.

     Returns:
         An array of ``shape``. When it is viewed as (rows, cols) along
         ``self.axis``, it has orthonormal columns if rows > cols and
         orthonormal rows otherwise. It is multiplied by ``self.scale``.

     Raises:
         ValueError: if ``shape`` has fewer than two dimensions.
     """
     if len(shape) < 2:
         raise ValueError("Orthogonal initializer requires at least a 2D shape.")

     axis = self.axis
     rows = shape[axis]
     cols = np.prod(shape) // rows

     # QR is done on the tall orientation so Q has orthonormal columns.
     if rows > cols:
         qr_shape = (rows, cols)
     else:
         qr_shape = (cols, rows)

     gaussian = jax.random.normal(hooks.next_rng_key(), qr_shape, dtype)
     q, r = jnp.linalg.qr(gaussian)
     # Flip column signs by sign(diag(R)) so Q is uniformly distributed rather
     # than biased by the QR sign convention.
     q = q * jnp.sign(jnp.diag(r))

     # Only a strictly wide target needs the transpose; square stays as-is.
     if rows < cols:
         q = q.T

     other_dims = tuple(np.delete(shape, axis))
     q = jnp.moveaxis(jnp.reshape(q, (rows, *other_dims)), 0, axis)

     scale = jax.lax.convert_element_type(self.scale, dtype)
     return scale * q
Example #3
    def call(
        x: np.ndarray,
        training: tp.Optional[bool] = None,
        rng: tp.Optional[np.ndarray] = None,
    ) -> jnp.ndarray:
            x: The value to be dropped out.
            training: Whether training is currently happening.
            rng: Optional RNGKey.
            x but dropped out and scaled by `1 / (1 - rate)`.
        if training is None:
            training = hooks.is_training()

        return hk.dropout(
            rng=rng if rng is not None else hooks.next_rng_key(),
            rate=self.rate if training else 0.0,
Example #4
 def __call__(self, shape: Shape, dtype: DType) -> np.ndarray:
     """Draw i.i.d. normal samples with mean ``self.mean`` and stddev ``self.stddev``.

     Args:
         shape: Shape of the returned array.
         dtype: Dtype of the samples; the mean and stddev are cast to it too.

     Returns:
         ``mean + stddev * z``, where ``z`` is a standard normal array drawn
         with a fresh key from ``hooks.next_rng_key()``.
     """
     mean = jax.lax.convert_element_type(self.mean, dtype)
     stddev = jax.lax.convert_element_type(self.stddev, dtype)
     noise = jax.random.normal(hooks.next_rng_key(), shape, dtype)
     return mean + stddev * noise
Example #5
 def __call__(self, shape: Shape, dtype: DType) -> np.ndarray:
     return jax.random.uniform(
         hooks.next_rng_key(), shape, dtype, self.minval, self.maxval