def Init(shape, rng):
  """Returns `stddev`-scaled orthogonalized random normal values of `shape`."""
  # Left-pad with unit dimensions so that there are at least two of them.
  dims = list(shape)
  dims = [1] * max(0, 2 - len(dims)) + dims

  # Collapse every leading dimension into a single row count; the last
  # dimension is kept as the column count.
  *leading, n_cols = dims
  n_rows = 1
  for size in leading:
    n_rows *= size

  # Always factorize a matrix whose first dimension is the larger one; a
  # "wide" target is recovered by transposing afterwards.
  is_wide = n_rows < n_cols
  flat_shape = (max(n_rows, n_cols), min(n_rows, n_cols))

  samples = random.normal(rng, flat_shape, dtype=jnp.float32)
  q, r = jnp.linalg.qr(samples)

  # Multiply each column of Q by the sign of the matching diagonal entry of R.
  q = q * jnp.sign(jnp.diag(r))

  if is_wide:
    q = jnp.transpose(q)

  # Restore the requested shape and apply the requested scale.
  return stddev * jnp.reshape(q, shape)
def Init(shape, rng):
  """Returns float32 random initial weights of the given `shape`.

  The target variance is `scale` divided by the fan quantity selected by
  `mode`; samples are then drawn from the configured `distribution`.

  Raises:
    ValueError: If `distribution` is not 'truncated_normal', 'normal' or
      'uniform'.
  """
  fan_in, fan_out = _GetFans(shape, out_dim, in_dim)

  # Variance of the produced weights; any other `mode` leaves `scale` as is.
  if mode == 'fan_in':
    variance = scale / fan_in
  elif mode == 'fan_out':
    variance = scale / fan_out
  elif mode == 'fan_avg':
    variance = scale / ((fan_in + fan_out) / 2)
  else:
    variance = scale

  if distribution == 'uniform':
    # A uniform distribution on [-bound, bound] has variance bound**2 / 3.
    bound = jnp.sqrt(3. * variance)
    return random.uniform(rng, shape, jnp.float32, -bound, bound)

  if distribution == 'normal':
    samples = random.normal(rng, shape) * jnp.sqrt(variance)
    return samples.astype('float32')

  if distribution == 'truncated_normal':
    # constant from scipy.stats.truncnorm.std(a=-2, b=2, loc=0., scale=1.)
    stddev = jnp.sqrt(variance) / .87962566103423978
    samples = random.truncated_normal(rng, -2, 2, shape) * stddev
    return samples.astype('float32')

  raise ValueError('invalid distribution for ScaleInitializer')
def RandomNormalInitializer(stddev=1e-2):
  """Returns an initializer producing `stddev`-scaled random normal values.

  Args:
    stddev: Factor that the drawn normal samples are multiplied by.

  Returns:
    A function taking `(shape, rng)` and returning a float32 array of `shape`.
  """

  def Init(shape, rng):
    scaled = stddev * random.normal(rng, shape)
    return scaled.astype('float32')

  return Init
def RandomNormalInitializer(stddev=1e-2):
  """Returns an initializer for random normal coefficients.

  Args:
    stddev: Factor that the drawn normal samples are multiplied by.

  Returns:
    A function taking `(shape, rng)` that draws normal samples for
    `_PureShape(shape)` using `rng`, scales them by `stddev` and returns the
    result as float32.
  """

  def Init(shape, rng):
    samples = random.normal(rng, _PureShape(shape))
    # Cast AFTER scaling. Casting only `samples` (as `a * b.astype(...)`
    # does, since the method call binds tighter than `*`) would let a
    # wider-typed `stddev` promote the product away from float32.
    return (stddev * samples).astype('float32')

  return Init