def init(key, shape, dtype=dtype):
  # `dtype`, `in_axis`, and `out_axis` are closed over from the enclosing
  # initializer factory.
  shape = core.as_named_shape(shape)
  fan_in, fan_out = _compute_fans(shape, in_axis, out_axis)
  # sqrt(1 / fan_in) is the half-width of the uniform interval, not a
  # variance.
  bound = jnp.sqrt(1 / fan_in)
  # Note: returns a 1-D vector of length fan_out (bias-style), not `shape`.
  return random.uniform(key, (int(fan_out),), dtype,
                        minval=-bound, maxval=bound)
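# A minimal, self-contained sketch of the factory pattern the `init` closure
# above assumes: `in_axis`, `out_axis`, and `dtype` are supplied by a maker
# function. The factory name `fan_in_uniform` is hypothetical, and plain
# tuple indexing stands in for `core.as_named_shape` / `_compute_fans`:

import jax.numpy as jnp
from jax import random

def fan_in_uniform(in_axis=-2, out_axis=-1, dtype=jnp.float32):
  def init(key, shape, dtype=dtype):
    fan_in, fan_out = shape[in_axis], shape[out_axis]
    bound = jnp.sqrt(1. / fan_in)  # uniform half-width, bias-style bound
    return random.uniform(key, (int(fan_out),), dtype,
                          minval=-bound, maxval=bound)
  return init

# Illustrative usage: a (3, 64) weight shape yields a length-64 bias vector.
# bias = fan_in_uniform()(random.PRNGKey(0), (3, 64))  # bias.shape == (64,)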
def threefry_random_bits(key: jnp.ndarray, bit_width, shape):
  """Sample uniform random bits of given width and shape using PRNG key."""
  if not _is_threefry_prng_key(key):
    raise TypeError("_random_bits got invalid prng key.")
  if bit_width not in (8, 16, 32, 64):
    raise TypeError("requires 8-, 16-, 32- or 64-bit field width.")
  shape = core.as_named_shape(shape)
  for name, size in shape.named_items:
    real_size = lax.psum(1, name)
    if real_size != size:
      raise ValueError(f"The shape of axis {name} was specified as {size}, "
                       f"but it really is {real_size}")
    # Fold each named axis's index into the key so every position along the
    # axis draws independent bits.
    axis_index = lax.axis_index(name)
    key = threefry_fold_in(key, axis_index)
  size = prod(shape.positional)
  # Compute ceil(bit_width * size / 32) in a way that is friendly to shape
  # polymorphism
  max_count, r = divmod(bit_width * size, 32)
  if r > 0:
    max_count += 1

  # A single threefry_2x32 call is driven by a uint32 counter, so one key
  # yields at most 2**32 - 1 words; larger requests are split across subkeys.
  if core.is_constant_dim(max_count):
    nblocks, rem = divmod(max_count, jnp.iinfo(np.uint32).max)
  else:
    nblocks, rem = 0, max_count

  if not nblocks:
    bits = threefry_2x32(key, lax.iota(np.uint32, rem))
  else:
    keys = threefry_split(key, nblocks + 1)
    subkeys, last_key = keys[:-1], keys[-1]
    blocks = vmap(threefry_2x32, in_axes=(0, None))(
        subkeys, lax.iota(np.uint32, jnp.iinfo(np.uint32).max))
    last = threefry_2x32(last_key, lax.iota(np.uint32, rem))
    bits = lax.concatenate([blocks.ravel(), last], 0)

  dtype = UINT_DTYPES[bit_width]
  if bit_width == 64:
    # Build 64-bit values by pairing the first half of the word stream (high
    # bits) with the second half (low bits).
    bits = [lax.convert_element_type(x, dtype) for x in jnp.split(bits, 2)]
    bits = lax.shift_left(bits[0], dtype(32)) | bits[1]
  elif bit_width in [8, 16]:
    # this is essentially bits.view(dtype)[:size]
    bits = lax.bitwise_and(
        np.uint32(np.iinfo(dtype).max),
        lax.shift_right_logical(
            lax.broadcast(bits, (1,)),
            lax.mul(
                np.uint32(bit_width),
                lax.broadcasted_iota(np.uint32, (32 // bit_width, 1), 0))))
    bits = lax.reshape(bits, (np.uint32(max_count * 32 // bit_width),), (1, 0))
    bits = lax.convert_element_type(bits, dtype)[:size]
  return lax.reshape(bits, shape)
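# A NumPy-only check of the packing arithmetic above, under the assumption of
# a little-endian host: for 8/16-bit widths the function draws
# ceil(bit_width * size / 32) uint32 words and shift-and-masks sub-words out
# of them, which is equivalent to `words.view(dtype)[:size]`. The concrete
# numbers (bit_width=16, size=5) are illustrative.

import numpy as np

bit_width, size = 16, 5
max_count = -(-(bit_width * size) // 32)       # ceil(80 / 32) == 3 words
words = np.arange(max_count, dtype=np.uint32)  # stand-in for threefry output
out_dtype = np.uint16

# Mirror of the lax shift/mask/transpose-reshape sequence in the 16-bit branch:
shifts = np.uint32(bit_width) * np.arange(32 // bit_width, dtype=np.uint32)
pieces = (words[None, :] >> shifts[:, None]) & np.iinfo(out_dtype).max
unpacked = pieces.T.ravel().astype(out_dtype)[:size]

assert np.array_equal(unpacked, words.view(out_dtype)[:size])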
def init(key, shape, dtype=dtype):
  shape = core.as_named_shape(shape)
  fan_in, fan_out = _compute_fans(shape, in_axis, out_axis)
  if mode == "fan_in":
    denominator = fan_in
  elif mode == "fan_out":
    denominator = fan_out
  elif mode == "fan_avg":
    denominator = (fan_in + fan_out) / 2
  else:
    raise ValueError(
        "invalid mode for variance scaling initializer: {}".format(mode))
  variance = jnp.array(scale / denominator, dtype=dtype)
  if distribution == "truncated_normal":
    # constant is stddev of standard normal truncated to (-2, 2)
    stddev = jnp.sqrt(variance) / jnp.array(.87962566103423978, dtype)
    return random.truncated_normal(key, -2, 2, shape, dtype) * stddev
  elif distribution == "normal":
    return random.normal(key, shape, dtype) * jnp.sqrt(variance)
  elif distribution == "uniform":
    return random.uniform(key, shape, dtype, -1) * jnp.sqrt(3 * variance)
  else:
    raise ValueError("invalid distribution for variance scaling initializer")
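# Usage sketch: a closure like the one above is what JAX's public
# `jax.nn.initializers.variance_scaling(scale, mode, distribution)` factory
# returns. The seed and the (784, 256) shape below are illustrative:

import jax.numpy as jnp
from jax import random
from jax.nn.initializers import variance_scaling

he_init = variance_scaling(2.0, "fan_in", "truncated_normal")  # He-style
W = he_init(random.PRNGKey(0), (784, 256), jnp.float32)
# Dividing by 0.87962... (the stddev of a standard normal truncated to
# (-2, 2)) rescales the samples so W.var() lands near scale / fan_in,
# i.e. 2.0 / 784 here.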
def init(key, shape, dtype=dtype):
  dtype = dtypes.canonicalize_dtype(dtype)
  shape = core.as_named_shape(shape)
  fan_in, fan_out = _compute_fans(shape, in_axis, out_axis, batch_axis)
  if mode == "fan_in":
    denominator = fan_in
  elif mode == "fan_out":
    denominator = fan_out
  elif mode == "fan_avg":
    denominator = (fan_in + fan_out) / 2
  else:
    raise ValueError(
        f"invalid mode for variance scaling initializer: {mode}")
  variance = jnp.array(scale / denominator, dtype=dtype)
  if distribution == "truncated_normal":
    if jnp.issubdtype(dtype, jnp.floating):
      # constant is stddev of standard normal truncated to (-2, 2)
      stddev = jnp.sqrt(variance) / jnp.array(.87962566103423978, dtype)
      return random.truncated_normal(key, -2, 2, shape, dtype) * stddev
    else:
      # constant is stddev of complex standard normal truncated to 2
      stddev = jnp.sqrt(variance) / jnp.array(.95311164380491208, dtype)
      return _complex_truncated_normal(key, 2, shape, dtype) * stddev
  elif distribution == "normal":
    return random.normal(key, shape, dtype) * jnp.sqrt(variance)
  elif distribution == "uniform":
    if jnp.issubdtype(dtype, jnp.floating):
      return random.uniform(key, shape, dtype, -1) * jnp.sqrt(3 * variance)
    else:
      return _complex_uniform(key, shape, dtype) * jnp.sqrt(variance)
  else:
    raise ValueError(
        f"invalid distribution for variance scaling initializer: {distribution}")
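# The complex branches keep the same contract, E[|W|^2] == scale / denominator,
# by drawing from a disk in the complex plane (`_complex_uniform`) or from a
# radially truncated complex normal (`_complex_truncated_normal`), both
# module-internal helpers not re-implemented here. A usage sketch, assuming a
# JAX version whose public `variance_scaling` includes this complex support;
# the seed and shape are illustrative:

import jax.numpy as jnp
from jax import random
from jax.nn.initializers import variance_scaling

glorot_c = variance_scaling(1.0, "fan_avg", "uniform")  # Glorot-style
Wc = glorot_c(random.PRNGKey(1), (64, 64), jnp.complex64)
# jnp.mean(jnp.abs(Wc) ** 2) should come out close to 1.0 / 64.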