Example #1
0
    def Init(shape, rng):
        """Returns orthogonalized random normal values with the given `shape`.

        All leading dimensions are flattened into rows and the last dimension
        is kept as columns (a shape with fewer than two dimensions is treated
        as if left-padded with 1s). A float32 standard-normal matrix of that
        2-D size is orthogonalized with a QR decomposition, reshaped back to
        `shape`, and scaled by `stddev`.

        Args:
          shape: Target shape of the returned array.
          rng: PRNG key forwarded to `random.normal`.

        Returns:
          `stddev` times an array of shape `shape` built from the Q factor.
        """
        dims = list(shape)
        dims = [1] * (2 - len(dims)) + dims
        *leading, n_cols = dims
        n_rows = 1
        for size in leading:
            n_rows = n_rows * size

        # Work on the tall orientation so the reduced-mode Q factor has the
        # same shape as the sampled matrix; remember whether that required a
        # transpose so it can be undone afterwards.
        wide = n_rows < n_cols
        flat_shape = (n_cols, n_rows) if wide else (n_rows, n_cols)

        gaussian = random.normal(rng, flat_shape, dtype=jnp.float32)
        q, r = jnp.linalg.qr(gaussian)
        # Multiply each column of Q by the sign of the matching diagonal
        # entry of R, removing QR's sign ambiguity so Q is uniformly
        # distributed ("make Q uniform").
        q = q * jnp.sign(jnp.diag(r))

        if wide:
            q = q.T
        return stddev * jnp.reshape(q, shape)
Example #2
0
 def Init(shape, rng, nonreceptive_dims=None):
     """Returns random values for initializing weights of the given `shape`.

     Samples from `distribution` with variance `scale / fan`, where `fan` is
     selected by `mode`: the input fan ('fan_in'), the output fan
     ('fan_out'), or their mean ('fan_avg').

     Args:
       shape: Requested weight shape; normalized with `_PureShape`.
       rng: PRNG key passed to the `random` sampler.
       nonreceptive_dims: Passed through to `_GetFans` when computing fans.

     Returns:
       A float32 array with the normalized `shape`.

     Raises:
       ValueError: If `distribution` is not one of 'truncated_normal',
         'normal' or 'uniform'.
     """
     shape = _PureShape(shape)
     fan_in, fan_out = _GetFans(shape, out_dim, in_dim, nonreceptive_dims)
     # Divide out-of-place: `gain /= ...` on an alias of the closed-over
     # `scale` would mutate `scale` itself if it were a mutable ndarray,
     # compounding the division on every call.
     gain = scale
     if mode == 'fan_in':
         gain = gain / fan_in
     elif mode == 'fan_out':
         gain = gain / fan_out
     elif mode == 'fan_avg':
         gain = gain / ((fan_in + fan_out) / 2)
     # NOTE(review): any other `mode` leaves gain == scale; presumably the
     # enclosing factory validates `mode` -- confirm.
     if distribution == 'truncated_normal':
         # Std of a unit normal truncated to [-2, 2]
         # (scipy.stats.truncnorm.std(a=-2, b=2, loc=0., scale=1.)); dividing
         # by it restores the target variance after truncation.
         stddev = jnp.sqrt(gain) / .87962566103423978
         new_weights = random.truncated_normal(rng, -2, 2, shape) * stddev
         return new_weights.astype('float32')
     elif distribution == 'normal':
         new_weights = random.normal(rng, shape) * jnp.sqrt(gain)
         return new_weights.astype('float32')
     elif distribution == 'uniform':
         # U(-lim, lim) has variance lim**2 / 3, so lim = sqrt(3 * gain).
         lim = jnp.sqrt(3. * gain)
         return random.uniform(rng, shape, jnp.float32, -lim, lim)
     else:
         raise ValueError(
             'invalid distribution for ScaleInitializer: {!r}'.format(
                 distribution))
Example #3
0
def RandomNormalInitializer(stddev=1e-2):
    """Returns an initializer for random normal coefficients.

    Args:
      stddev: Factor the unit-normal samples are multiplied by.

    Returns:
      A function `Init(shape, rng)` that draws `random.normal` samples with
      the shape given by `_PureShape(shape)`, multiplies them by `stddev`,
      and returns the result as float32.
    """

    def Init(shape, rng):
        # Scale first, then cast, so the float32 result holds even when
        # `stddev` is itself a non-weakly-typed higher-precision value
        # (e.g. a float64 array with x64 enabled).
        samples = random.normal(rng, _PureShape(shape))
        return (stddev * samples).astype('float32')

    return Init