from trax import fastmath
from trax.fastmath import numpy as jnp
from trax.fastmath import random


# `Init` closes over `out_dim`, `in_dim`, `scale`, `mode` and `distribution`,
# so it is wrapped in its enclosing factory here; `_PureShape` and `_GetFans`
# are module-local helpers that normalize the shape and compute fan-in/out.
def ScaledInitializer(out_dim, in_dim, scale, mode, distribution):
  """Returns an initializer whose scale is adjusted to the weight shape."""

  def Init(shape, rng, nonreceptive_dims=None):
    """Returns random values for initializing weights of the given `shape`."""
    shape = _PureShape(shape)
    fan_in, fan_out = _GetFans(shape, out_dim, in_dim, nonreceptive_dims)
    gain = scale
    if mode == 'fan_in':
      gain /= fan_in
    elif mode == 'fan_out':
      gain /= fan_out
    elif mode == 'fan_avg':
      gain /= (fan_in + fan_out) / 2
    if distribution == 'truncated_normal':
      # constant from scipy.stats.truncnorm.std(a=-2, b=2, loc=0., scale=1.)
      stddev = jnp.sqrt(gain) / .87962566103423978
      new_weights = random.truncated_normal(rng, -2, 2, shape) * stddev
      return new_weights.astype('float32')
    elif distribution == 'normal':
      new_weights = random.normal(rng, shape) * jnp.sqrt(gain)
      return new_weights.astype('float32')
    elif distribution == 'uniform':
      lim = jnp.sqrt(3. * gain)
      return random.uniform(rng, shape, jnp.float32, -lim, lim)
    else:
      raise ValueError('invalid distribution for ScaledInitializer')

  return Init
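# Usage sketch (illustrative, not library code): a Glorot-uniform-style
# initializer built from the factory above, with fan dimensions read from the
# last two axes. `random.get_prng` is assumed here as the PRNG-key
# constructor; any JAX-style key works.
glorot_uniform_init = ScaledInitializer(
    out_dim=-1, in_dim=-2, scale=1., mode='fan_avg', distribution='uniform')
weights = glorot_uniform_init((512, 1024), random.get_prng(0))
assert weights.shape == (512, 1024)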
# `forward` is a method of a Serial-like combinator (Trax's `SkippingSerial`);
# `_split_rngs`, `_inputs_from_stack` and `_outputs_onto_stack` are the stack
# helpers from `trax.layers.combinators`.
def forward(self, xs):
  """Executes this layer as part of a forward pass through the model."""
  self._validate_forward_inputs(xs)
  (step, layers_state) = self.state
  # Get N+1 rngs, N for running layers and one extra.
  rngs = _split_rngs(self.rng, self._n_layers + 1)
  rng0, rngs = rngs[0], rngs[1:]
  if not self.sublayers:  # No-op: leave args unchanged.
    self.state = (step + 1, layers_state)
    return xs

  # Prepare the stack and do some safety checks as in the parent class.
  stack = xs
  new_state = []
  n_layers = self._n_layers
  weights = self.weights
  if n_layers != 1 and len(weights) != n_layers:
    raise ValueError('number of weights ({}) not equal to number of layers '
                     '({})'.format(len(weights), n_layers))
  if n_layers != 1 and len(layers_state) != n_layers:
    raise ValueError('length of state ({}) not equal to number of layers '
                     '({})'.format(len(layers_state), n_layers))

  # TODO(chowdhery): try different strategies, also try running not all
  # layers backwards by using fastmath.stop_gradient where needed.

  # Calculate how many layers to run forward.
  if self._mode == 'train':
    # warmup goes from 1.0 at start to 0.0 at skipping_warmup_steps and after.
    w_steps = float(self._skipping_warmup_steps)
    f_step = step.astype(jnp.float32)
    warmup = jnp.maximum(0.0, (w_steps - f_step) / w_steps)
    # low is the minimum number of layers to *not* skip, from n_layers to 0.
    low = warmup * float(n_layers)
    # high should be so that (high - n_layers) / high = 1.0 - skip_fraction,
    # because (high - n_layers) / high is the probability we're not skipping
    # (after warmup); so high - n_layers = high - high * skip_fraction.
    high = float(n_layers) / self._skip_fraction
    # We want the same rng0 on all cores.
    if fastmath.device_count() > 1:
      rng0 = fastmath.psum(rng0, 'batch')
    n_forward_layers = random.uniform(rng0, (), jnp.float32, low, high)
  else:
    n_forward_layers = float(n_layers)

  # Run the layers, skipping those whose index is past `n_forward_layers`.
  cur_layer_idx = 0.0
  for layer, p, s, rng in zip(self.sublayers, weights, layers_state, rngs):
    inputs = _inputs_from_stack(layer, stack)

    def CondF(t):
      return layer.pure_fn(t[0], t[1], t[2], t[3])  # pylint: disable=cell-var-from-loop

    def PassF(t):
      return t[0], t[2]

    outputs, s = fastmath.cond(
        fastmath.lt(cur_layer_idx, n_forward_layers), CondF, PassF,
        (inputs, p, s, rng))
    stack = _outputs_onto_stack(layer, outputs, stack)
    new_state.append(s)
    cur_layer_idx += 1.0
  self.state = (step + 1, new_state)
  return stack
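# Worked example of the skipping schedule above (illustrative numbers, plain
# Python): with n_layers=6, skip_fraction=0.1 and 1000 warmup steps, `low`
# decays from 6 to 0 over warmup while `high` = 6 / 0.1 = 60, so after warmup
# a draw from U(0, 60) runs all 6 layers with probability (60 - 6) / 60 = 0.9.
def skip_bounds(step, n_layers=6, skip_fraction=0.1, warmup_steps=1000):
  """Returns the (low, high) bounds for the `n_forward_layers` draw."""
  warmup = max(0.0, (warmup_steps - step) / warmup_steps)
  low = warmup * n_layers           # minimum number of layers that must run
  high = n_layers / skip_fraction   # P(all run) = (high - n_layers) / high
  return low, high

print(skip_bounds(0))     # (6.0, 60.0): warmup start, every layer runs
print(skip_bounds(500))   # (3.0, 60.0): mid-warmup, at least 3 layers run
print(skip_bounds(2000))  # (0.0, 60.0): after warmup, only skip_fraction caps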
def RandomUniformInitializer(lim=1.0):
  """Returns an initializer for random uniform coefficients."""
  # Make sure shape does not contain int tensors by calling int() below.
  return lambda shape, rng: random.uniform(  # pylint: disable=g-long-lambda
      rng, _PureShape(shape), jnp.float32, -lim, lim)
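# Usage sketch (illustrative): coefficients drawn uniformly from [-0.1, 0.1].
# As above, `random.get_prng` is assumed as the PRNG-key constructor.
init = RandomUniformInitializer(lim=0.1)
w = init((3, 4), random.get_prng(0))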
def AtariConvInit(kernel_shape, rng, dtype=jnp.float32):
  """The standard init for Conv layers and Atari."""
  filter_height, filter_width, fan_in, _ = kernel_shape
  std = 1 / jnp.sqrt(fan_in * filter_height * filter_width)
  return random.uniform(rng, kernel_shape, dtype, minval=-std, maxval=std)
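# Usage sketch (illustrative): an 8x8 kernel over 4 stacked input frames with
# 32 output channels, matching the classic first conv layer of the Atari DQN.
kernel = AtariConvInit((8, 8, 4, 32), random.get_prng(0))
assert kernel.shape == (8, 8, 4, 32)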