def __call__(self, shape: Shape, dtype: DType) -> np.ndarray:
    m = jax.lax.convert_element_type(self.mean, dtype)
    s = jax.lax.convert_element_type(self.stddev, dtype)
    # Draw from a standard normal truncated to [-2, 2], then shift and
    # scale to the requested mean and stddev.
    unscaled = jax.random.truncated_normal(
        hooks.next_rng_key(), -2.0, 2.0, shape, dtype
    )
    return s * unscaled + m
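# A minimal standalone sketch of the same scaled truncated-normal draw,
# using an explicit PRNG key in place of hooks.next_rng_key() (the class
# wrapper holding `mean`/`stddev` is assumed and not shown here):
import jax
import jax.numpy as jnp

def truncated_normal_init(key, shape, dtype=jnp.float32, mean=0.0, stddev=1.0):
    # Samples come from a standard normal truncated to [-2, 2], then are
    # shifted/scaled to the requested moments.
    unscaled = jax.random.truncated_normal(key, -2.0, 2.0, shape, dtype)
    return mean + stddev * unscaled

# Example: w = truncated_normal_init(jax.random.PRNGKey(0), (128, 64))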
def __call__(self, shape: Shape, dtype: DType) -> np.ndarray:
    if len(shape) < 2:
        raise ValueError("Orthogonal initializer requires at least a 2D shape.")
    n_rows = shape[self.axis]
    n_cols = np.prod(shape) // n_rows
    matrix_shape = (n_rows, n_cols) if n_rows > n_cols else (n_cols, n_rows)
    norm_dst = jax.random.normal(hooks.next_rng_key(), matrix_shape, dtype)
    q_mat, r_mat = jnp.linalg.qr(norm_dst)
    # Enforce that Q is uniformly distributed by fixing the sign
    # ambiguity of the QR decomposition.
    q_mat *= jnp.sign(jnp.diag(r_mat))
    if n_rows < n_cols:
        q_mat = q_mat.T
    q_mat = jnp.reshape(q_mat, (n_rows,) + tuple(np.delete(shape, self.axis)))
    q_mat = jnp.moveaxis(q_mat, 0, self.axis)
    return jax.lax.convert_element_type(self.scale, dtype) * q_mat
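# Standalone check of the property this initializer relies on: the Q factor
# of a QR-decomposed Gaussian matrix has orthonormal columns, and multiplying
# by sign(diag(R)) removes the per-column sign ambiguity so Q is Haar-uniform.
# Uses an explicit PRNG key instead of hooks.next_rng_key() (an assumption):
import jax
import jax.numpy as jnp

g = jax.random.normal(jax.random.PRNGKey(0), (8, 4))
q, r = jnp.linalg.qr(g)
q = q * jnp.sign(jnp.diag(r))  # fix per-column signs
# Columns are orthonormal: Q^T Q is (numerically) the identity.
assert jnp.allclose(q.T @ q, jnp.eye(4), atol=1e-5)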
def call(
    self,
    x: np.ndarray,
    training: tp.Optional[bool] = None,
    rng: tp.Optional[np.ndarray] = None,
) -> jnp.ndarray:
    """
    Arguments:
        x: The value to be dropped out.
        training: Whether training is currently happening.
        rng: Optional RNGKey.

    Returns:
        `x` with elements dropped out and the rest scaled by `1 / (1 - rate)`.
    """
    if training is None:
        training = hooks.is_training()

    # When not training, a rate of 0.0 makes hk.dropout a no-op.
    return hk.dropout(
        rng=rng if rng is not None else hooks.next_rng_key(),
        rate=self.rate if training else 0.0,
        x=x,
    )
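# A standalone sketch of the inverted-dropout math that hk.dropout computes:
# each element survives with probability 1 - rate and is rescaled by
# 1 / (1 - rate) so the output matches the input in expectation. This is
# equivalent math in plain JAX, not Haiku's actual implementation:
import jax
import jax.numpy as jnp

def dropout_sketch(key, rate, x):
    keep = jax.random.bernoulli(key, 1.0 - rate, x.shape)
    return jnp.where(keep, x / (1.0 - rate), 0.0)

x = jnp.ones((4,))
y = dropout_sketch(jax.random.PRNGKey(0), 0.5, x)  # entries are 0.0 or 2.0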
def __call__(self, shape: Shape, dtype: DType) -> np.ndarray:
    m = jax.lax.convert_element_type(self.mean, dtype)
    s = jax.lax.convert_element_type(self.stddev, dtype)
    # Affine-transform standard normal samples to the requested moments.
    return m + s * jax.random.normal(hooks.next_rng_key(), shape, dtype)
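# Quick standalone check of the reparameterization above: mean + stddev * z
# with z ~ N(0, 1) yields samples with the requested moments (explicit key
# substituted for hooks.next_rng_key(), which is an assumption here):
import jax

samples = 3.0 + 0.5 * jax.random.normal(jax.random.PRNGKey(0), (100_000,))
print(samples.mean(), samples.std())  # approximately 3.0 and 0.5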
def __call__(self, shape: Shape, dtype: DType) -> np.ndarray:
    return jax.random.uniform(
        hooks.next_rng_key(), shape, dtype, self.minval, self.maxval
    )
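# jax.random.uniform takes minval/maxval positionally after dtype, so the
# call above samples from the half-open interval [minval, maxval). A
# standalone check with an explicit key (an assumption, as above):
import jax
import jax.numpy as jnp

u = jax.random.uniform(jax.random.PRNGKey(0), (1000,), jnp.float32, -1.0, 1.0)
assert (u >= -1.0).all() and (u < 1.0).all()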