Beispiel #1
0
 def __call__(self, shape: Shape, dtype: DType) -> np.ndarray:
     """Return a (semi-)orthogonal tensor of ``shape``, scaled by ``self.scale``.

     The tensor is treated as a matrix whose rows lie along ``self.axis`` and
     whose columns are the flattened remaining dimensions. A Gaussian matrix
     is QR-decomposed. The sign-corrected ``Q`` factor is reshaped back to
     ``shape``, and the result is multiplied by ``self.scale``.

     Args:
         shape: Target shape; must have at least two dimensions.
         dtype: dtype used for sampling and for the scale factor.

     Returns:
         An array of shape ``shape``.

     Raises:
         ValueError: If ``shape`` has fewer than two dimensions.
     """
     if len(shape) <= 1:
         raise ValueError(
             "Orthogonal initializer requires at least a 2D shape.")
     axis = self.axis
     rows = shape[axis]
     cols = np.prod(shape) // rows
     # QR is applied to a tall-or-square matrix: the larger extent goes first.
     sample_shape = (rows, cols) if rows > cols else (cols, rows)
     gaussian = jax.random.normal(module.next_rng_key(), sample_shape, dtype)
     q, r = jnp.linalg.qr(gaussian)
     # Flip column signs by sign(diag(R)) so that Q is uniformly distributed
     # rather than biased by the QR sign convention.
     q = q * jnp.sign(jnp.diag(r))
     if cols > rows:
         # The wide case was sampled transposed; undo that here.
         q = q.T
     other_dims = tuple(np.delete(shape, axis))
     q = jnp.reshape(q, (rows,) + other_dims)
     q = jnp.moveaxis(q, 0, axis)
     gain = jax.lax.convert_element_type(self.scale, dtype)
     return gain * q
Beispiel #2
0
    def call(
        self,
        x: np.ndarray,
        training: tp.Optional[bool] = None,
        rng: tp.Optional[np.ndarray] = None,
    ) -> jnp.ndarray:
        """Randomly zero elements of ``x`` while training.

        Arguments:
            x: The value to be dropped out.
            training: Whether training is currently happening. If ``None``,
                ``module.is_training()`` decides.
            rng: Optional RNGKey. If ``None``, a key is drawn from
                ``module.next_rng_key()``, but only when training.
        Returns:
            While training, x dropped out and scaled by `1 / (1 - rate)`;
            otherwise x itself, unchanged.
        """
        if training is None:
            training = module.is_training()

        # Outside training, dropout is the identity. Return early so that no
        # RNG key is requested (or consumed) in inference mode, where one may
        # not be available.
        if not training:
            return x

        return hk.dropout(
            rng=rng if rng is not None else module.next_rng_key(),
            rate=self.rate,
            x=x,
        )
Beispiel #3
0
 def __call__(self, shape: Shape, dtype: DType) -> np.ndarray:
     """Sample a truncated normal shifted by ``self.mean``, scaled by ``self.stddev``.

     Standard-normal samples restricted to ``[-2, 2]`` are multiplied by
     ``self.stddev`` and then offset by ``self.mean``.

     Args:
         shape: Shape of the returned array.
         dtype: dtype used for sampling and for the mean/stddev constants.

     Returns:
         An array of shape ``shape``.
     """
     mean = jax.lax.convert_element_type(self.mean, dtype)
     stddev = jax.lax.convert_element_type(self.stddev, dtype)
     lower, upper = -2.0, 2.0  # truncation bounds of the unit-scale draw
     standard = jax.random.truncated_normal(
         module.next_rng_key(), lower, upper, shape, dtype)
     return stddev * standard + mean
Beispiel #4
0
 def __call__(self, shape: Shape, dtype: DType) -> np.ndarray:
     """Sample ``self.mean + self.stddev * N(0, 1)`` with the given shape.

     Args:
         shape: Shape of the returned array.
         dtype: dtype used for sampling and for the mean/stddev constants.

     Returns:
         An array of shape ``shape``.
     """
     mean = jax.lax.convert_element_type(self.mean, dtype)
     stddev = jax.lax.convert_element_type(self.stddev, dtype)
     noise = jax.random.normal(module.next_rng_key(), shape, dtype)
     return mean + stddev * noise
Beispiel #5
0
 def __call__(self, shape: Shape, dtype: DType) -> np.ndarray:
     """Sample uniformly between ``self.minval`` and ``self.maxval``.

     The bounds are interpreted as by ``jax.random.uniform``.

     Args:
         shape: Shape of the returned array.
         dtype: dtype of the returned array.

     Returns:
         An array of shape ``shape`` and dtype ``dtype``.
     """
     key = module.next_rng_key()
     return jax.random.uniform(
         key, shape=shape, dtype=dtype,
         minval=self.minval, maxval=self.maxval)