def forward(self, xs):
  # TODO(jaszczur): modify; it's a copy from SkippingSerial
  self._validate_forward_inputs(xs)
  layers_state = self.state
  # Get 3 rngs, one for each layer.
  rngs = _split_rngs(self.rng, 3)
  # Prepare the stack and do some safety checks as in the parent class.
  stack = _make_tuple(xs)
  weights = self.weights
  if len(weights) != 3:
    raise ValueError('number of weights ({}) not equal to 3'
                     .format(len(weights)))
  if len(layers_state) != 3:
    raise ValueError('length of state ({}) not equal to 3'
                     .format(len(layers_state)))

  def true_func(t):
    outputs, new_true_state = self._true.pure_fn(
        t[0][0], t[1][0], t[2][0], t[3][0])
    # t[2][1] is old_false_state, which is not changing if true is executed.
    return outputs, (new_true_state, t[2][1])

  def false_func(t):
    if self._identity_false_fun:
      # Memory optimization: we don't need the pure_fn call.
      return t[0][1], t[2]
    outputs, new_false_state = self._false.pure_fn(
        t[0][1], t[1][1], t[2][1], t[3][1])
    # t[2][0] is old_true_state, which is not changing if false is executed.
    return outputs, (t[2][0], new_false_state)

  cond_inputs = _inputs_from_stack(self._cond, xs)
  cond_output, s = self._cond.pure_fn(cond_inputs, self.weights[0],
                                      self.state[0], rngs[0], use_cache=True)
  stack = _outputs_onto_stack(self._cond, [], stack)
  self._cond.state = s

  outputs, both_states = fastmath.cond(
      cond_output,
      true_func,
      false_func,
      [(stack, stack),
       (self.weights[1], self.weights[2]),
       (self.state[1], self.state[2]),
       (rngs[1], rngs[2])])

  # We don't know which (`true` or `false`) branch was run, but both of them
  # add (n_out) and remove (n_in) the same number of elements of the stack
  # (this was checked in __init__). _outputs_onto_stack just uses the layer's
  # n_in and n_out, so we can pass either `true` or `false` to it. Note that
  # `outputs` is the actual output of the `true` or `false` branch, whichever
  # was run, and we add it to the stack in either case.
  stack = _outputs_onto_stack(self._true, outputs, stack)
  self._true.state = both_states[0]
  self._false.state = both_states[1]
  return _make_singleitem_or_original(stack)
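
# A minimal, self-contained sketch of the state-threading pattern used in
# forward above. fastmath.cond mirrors jax.lax.cond, so both branches must
# return pytrees of identical structure; each branch therefore returns *both*
# states and replaces only its own, letting the other state pass through
# unchanged. Everything below (names included) is illustrative only, not part
# of this module's API.
def _toy_cond_state_threading():
  import jax
  import jax.numpy as jnp

  def toy_true(t):
    x, (true_state, false_state) = t
    return x + 1.0, (true_state + 1, false_state)  # Only true state advances.

  def toy_false(t):
    x, (true_state, false_state) = t
    return x - 1.0, (true_state, false_state + 1)  # Only false state advances.

  out, states = jax.lax.cond(
      jnp.array(True), toy_true, toy_false,
      (jnp.array(2.0), (jnp.array(0), jnp.array(0))))
  return out, states  # out == 3.0, states == (1, 0): false state untouched.
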
def forward(self, inputs):
  state = self.state
  observations = inputs
  if self._mode == 'collect':
    # Accumulate statistics only in the collect mode, i.e. when collecting
    # data using the agent.
    for observation in observations[:, -1]:  # (batch_size, time, ...)
      # Update statistics for each observation separately for simplicity.
      # Currently during data collection the batch size is 1 anyway.
      count = running_mean_and_variance_get_count(state)
      state = fastmath.cond(
          count < self._sample_limit,
          true_operand=(observation, state),
          true_fun=lambda args: running_mean_and_variance_update(*args),
          false_operand=None,
          false_fun=lambda _: state,
      )
  mean = running_mean_and_variance_get_mean(state)
  var = running_mean_and_variance_get_variance(state)
  norm_observations = (observations - mean) / (var ** 0.5 + self._epsilon)
  self.state = state
  return norm_observations
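
# The running_mean_and_variance_* helpers used above are defined elsewhere in
# this module. The sketch below is an assumption about the Welford-style
# accumulator they are expected to resemble (toy names, illustrative only):
# state = (count, mean, m2), where m2 accumulates sum((x - mean)^2), so the
# population variance is m2 / count.
import collections

_ToyRunningStats = collections.namedtuple(
    '_ToyRunningStats', ['count', 'mean', 'm2'])

def _toy_running_update(x, state):
  count = state.count + 1
  delta = x - state.mean
  mean = state.mean + delta / count
  m2 = state.m2 + delta * (x - mean)  # Uses both the old and the new mean.
  return _ToyRunningStats(count, mean, m2)

def _toy_running_variance(state):
  return state.m2 / state.count

# Example: updating with 1.0, 2.0, 4.0 gives mean 7/3 and variance 14/9.
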
def forward(self, xs):
  self._validate_forward_inputs(xs)
  (step, layers_state) = self.state
  # Get N+1 rngs, N for running layers and one extra.
  rngs = _split_rngs(self.rng, self._n_layers + 1)
  rng0, rngs = rngs[0], rngs[1:]
  if not self.sublayers:  # No-op: leave args unchanged.
    self.state = (step + 1, layers_state)
    return xs

  # Prepare the stack and do some safety checks as in the parent class.
  stack = xs
  new_state = []
  n_layers = self._n_layers
  weights = self.weights
  if n_layers != 1 and len(weights) != n_layers:
    raise ValueError('number of weights ({}) not equal to number of layers '
                     '({})'.format(len(weights), n_layers))
  if n_layers != 1 and len(layers_state) != n_layers:
    raise ValueError('length of state ({}) not equal to number of layers '
                     '({})'.format(len(layers_state), n_layers))

  # TODO(chowdhery): try different strategies, also try running not all
  # layers backwards by using fastmath.stop_gradient where needed.

  # Calculate how many layers to run forward.
  if self._mode == 'train':
    # Warmup goes from 1.0 at start to 0.0 at skipping_warmup_steps and after.
    w_steps = float(self._skipping_warmup_steps)
    f_step = step.astype(jnp.float32)
    warmup = jnp.maximum(0.0, (w_steps - f_step) / w_steps)
    # `low` is the minimum number of layers to *not* skip: n_layers down to 0.
    low = warmup * float(n_layers)
    # `high` should be such that (high - n_layers) / high = 1.0 - skip_fraction,
    # because (high - n_layers) / high is the probability we're not skipping
    # (after warmup); so high - n_layers = high - high * skip_fraction, which
    # gives high = n_layers / skip_fraction.
    high = float(n_layers) / self._skip_fraction
    # We want the same rng0 on all cores.
    if fastmath.device_count() > 1:
      rng0 = fastmath.psum(rng0, 'batch')
    n_forward_layers = random.uniform(rng0, (), jnp.float32, low, high)
  else:
    n_forward_layers = float(n_layers)

  # Run the layers, skipping those past n_forward_layers.
  cur_layer_idx = 0.0
  for layer, p, s, rng in zip(self.sublayers, weights, layers_state, rngs):
    inputs = _inputs_from_stack(layer, stack)

    def CondF(t):
      return layer.pure_fn(t[0], t[1], t[2], t[3])  # pylint: disable=cell-var-from-loop

    def PassF(t):
      return t[0], t[2]  # Identity: pass inputs and state through unchanged.

    outputs, s = fastmath.cond(
        fastmath.lt(cur_layer_idx, n_forward_layers),
        CondF, PassF, (inputs, p, s, rng))
    stack = _outputs_onto_stack(layer, outputs, stack)
    new_state.append(s)
    cur_layer_idx += 1.0
  self.state = (step + 1, new_state)
  return stack
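
# Standalone sketch of the schedule computed in the 'train' branch above, with
# assumed example values (n_layers=6, 1000 warmup steps, skip_fraction=0.5)
# chosen purely for illustration. `low` decays linearly from n_layers to 0
# over the warmup while `high = n_layers / skip_fraction` stays constant, so
# the uniform(low, high) draw for n_forward_layers allows progressively more
# skipping.
def _toy_skip_bounds(step, n_layers=6, skipping_warmup_steps=1000,
                     skip_fraction=0.5):
  w_steps = float(skipping_warmup_steps)
  warmup = jnp.maximum(0.0, (w_steps - step) / w_steps)
  low = warmup * float(n_layers)          # Layers guaranteed to run.
  high = float(n_layers) / skip_fraction  # (high - n_layers) / high == 1 - skip_fraction.
  return low, high

# _toy_skip_bounds(jnp.float32(0.))    -> (6.0, 12.0): no skipping at the start.
# _toy_skip_bounds(jnp.float32(500.))  -> (3.0, 12.0): partway through warmup.
# _toy_skip_bounds(jnp.float32(1000.)) -> (0.0, 12.0): full skipping regime.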