def forward_with_state(self, inputs, weights=tl.EMPTY_WEIGHTS,
                       state=tl.EMPTY_STATE, rng=None):
  del weights
  del rng
  observations = inputs
  if self._mode == 'collect':
    # Accumulate statistics only in the collect mode, i.e. when collecting
    # data using the agent.
    for observation in observations[:, -1]:  # (batch_size, time, ...)
      # Update statistics for each observation separately for simplicity.
      # Currently during data collection the batch size is 1 anyway.
      count = running_mean_and_variance_get_count(state)
      state = math.cond(
          count < self._sample_limit,
          true_operand=(observation, state),
          true_fun=lambda args: running_mean_and_variance_update(*args),
          false_operand=None,
          false_fun=lambda _: state,
      )
  mean = running_mean_and_variance_get_mean(state)
  var = running_mean_and_variance_get_variance(state)
  norm_observations = (observations - mean) / (var**0.5 + self._epsilon)
  return (norm_observations, state)
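# A minimal sketch of the running-statistics helpers the layer above calls.
# The real helpers live elsewhere in the codebase; the Welford-style state
# layout (count, mean, m2) assumed here is for illustration only and may not
# match the actual implementation.
import numpy as onp  # Plain NumPy for the sketch; the layer itself uses trax.math.

def running_mean_and_variance_init(shape):
  # Hypothetical initializer: zero count, mean, and sum of squared deviations.
  return (onp.zeros(()), onp.zeros(shape), onp.zeros(shape))

def running_mean_and_variance_update(observation, state):
  # One Welford update step for a single observation.
  (count, mean, m2) = state
  count = count + 1
  delta = observation - mean
  mean = mean + delta / count
  m2 = m2 + delta * (observation - mean)
  return (count, mean, m2)

def running_mean_and_variance_get_count(state):
  return state[0]

def running_mean_and_variance_get_mean(state):
  return state[1]

def running_mean_and_variance_get_variance(state):
  (count, _, m2) = state
  return m2 / onp.maximum(count, 1)  # Guard against division by zero.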
def forward_with_state(self, xs, weights, state, rng):
  self._validate_forward_inputs(xs)
  (step, layers_state) = state
  # Get N+1 rngs, N for running layers and one extra.
  rngs = _split_rngs(rng, self._n_layers + 1)
  rng0, rngs = rngs[0], rngs[1:]
  if not self.sublayers:  # No-op: leave args unchanged.
    return (xs, (step + 1, layers_state))

  # Prepare the stack and do some safety checks as in the parent class.
  stack = xs
  new_state = []
  n_layers = self._n_layers
  if n_layers != 1 and len(weights) != n_layers:
    raise ValueError(
        'number of weights ({}) not equal to number of layers ({})'.format(
            len(weights), n_layers))
  if n_layers != 1 and len(layers_state) != n_layers:
    raise ValueError(
        'length of state ({}) not equal to number of layers ({})'.format(
            len(layers_state), n_layers))

  # TODO(chowdhery): try different strategies, also try running not all
  # layers backwards by using math.stop_gradient where needed.

  # Calculate how many layers to run forward.
  if self._mode == 'train':
    # Warmup goes from 1.0 at the start to 0.0 at skipping_warmup_steps
    # and stays at 0.0 afterwards.
    w_steps = float(self._skipping_warmup_steps)
    warmup = np.maximum(0.0, (w_steps - step.astype(np.float32)) / w_steps)
    # low is the minimum number of layers to *not* skip; it decays from
    # n_layers to 0 over the warmup period.
    low = warmup * float(n_layers)
    # Choose high so that (high - n_layers) / high = 1.0 - skip_fraction,
    # because (high - n_layers) / high is the probability of not skipping
    # any layer (after warmup); solving gives high = n_layers / skip_fraction.
    high = float(n_layers) / self._skip_fraction
    # We want the same rng0 on all cores.
    if math.device_count() > 1:
      rng0 = math.psum(rng0, 'batch')
    n_forward_layers = random.uniform(rng0, (), np.float32, low, high)
  else:
    n_forward_layers = float(n_layers)

  # Run the layers, skipping (identity) past the first n_forward_layers.
  cur_layer_idx = 0.0
  for layer, p, s, rng in zip(self.sublayers, weights, layers_state, rngs):
    inputs = _inputs_from_stack(layer, stack)
    outputs, s = math.cond(
        # Run the layer only while cur_layer_idx < n_forward_layers;
        # otherwise pass the inputs through unchanged.
        pred=math.lt(cur_layer_idx, n_forward_layers),
        true_operand=(inputs, p, s, rng),  # This tuple is t below.
        true_fun=(lambda t: layer.pure_fn(t[0], t[1], t[2], t[3])),  # pylint: disable=cell-var-from-loop
        false_operand=(inputs, p, s, rng),
        false_fun=(lambda t: (t[0], t[2])),  # Return (inputs, state).
    )
    stack = _outputs_onto_stack(layer, outputs, stack)
    new_state.append(s)
    cur_layer_idx += 1.0
  return stack, (step + 1, new_state)
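# A quick, self-contained check of the warmup arithmetic above. The concrete
# numbers are hypothetical; only the formulas mirror the layer code.
def skipping_bounds(step, n_layers=6, skip_fraction=0.5, warmup_steps=20000.0):
  warmup = max(0.0, (warmup_steps - float(step)) / warmup_steps)
  low = warmup * float(n_layers)          # Minimum number of layers to run.
  high = float(n_layers) / skip_fraction  # P(no skip) = (high - n) / high.
  return (low, high)

assert skipping_bounds(0) == (6.0, 12.0)      # Start of warmup: run all layers.
assert skipping_bounds(10000) == (3.0, 12.0)  # Mid-warmup: at least 3 run.
assert skipping_bounds(20000) == (0.0, 12.0)  # After warmup: (12 - 6) / 12 = 0.5
                                              # chance that no layer is skipped.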
def forward_with_state(self, xs, weights=tl.EMPTY_WEIGHTS,
                       state=tl.EMPTY_STATE, **kwargs):
  self._validate_forward_inputs(xs)
  # Get N+1 rngs, N for running layers and one extra.
  rngs = _pop_rng_and_split(kwargs, self._n_layers + 1)
  rng0, rngs = rngs[0], rngs[1:]
  if not self.sublayers:  # No-op: leave args unchanged.
    return (xs, state)

  # Prepare the stack and do some safety checks as in the parent class.
  stack = xs
  new_state = []
  n_layers = self._n_layers
  if n_layers != 1 and len(weights) != n_layers:
    raise ValueError(
        'number of weights ({}) not equal to number of layers ({})'.format(
            len(weights), n_layers))
  if n_layers != 1 and len(state) != n_layers:
    raise ValueError(
        'length of state ({}) not equal to number of layers ({})'.format(
            len(state), n_layers))

  # TODO(chowdhery): try different strategies, also try running not all
  # layers backwards by using math.stop_gradient where needed.

  # Calculate how many layers to run forward.
  if self._mode == 'train':
    n_forward_layers = random.uniform(rng0, (), np.float32, 0.0, n_layers)
  else:
    n_forward_layers = float(n_layers)

  # Run the layers, skipping (identity) past the first n_forward_layers.
  cur_layer_idx = 0.0
  for layer, p, s, rng in zip(self.sublayers, weights, state, rngs):
    inputs = _inputs_from_stack(layer, stack)
    outputs, s = math.cond(
        # Run the layer only while cur_layer_idx < n_forward_layers;
        # otherwise pass the inputs through unchanged.
        pred=math.lt(cur_layer_idx, n_forward_layers),
        true_operand=(inputs, p, s, rng),  # This tuple is t below.
        true_fun=(lambda t: layer.pure_fn(t[0], t[1], t[2], t[3])),  # pylint: disable=cell-var-from-loop
        false_operand=(inputs, p, s, rng),
        false_fun=(lambda t: (t[0], t[2])),  # Return (inputs, state).
    )
    stack = _outputs_onto_stack(layer, outputs, stack)
    new_state.append(s)
    cur_layer_idx += 1.0
  return stack, new_state
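# Unlike the warmup variant above, this version draws n_forward_layers
# uniformly from [0, n_layers] in train mode, so layer i runs with
# probability (n_layers - i) / n_layers and (n_layers + 1) / 2 layers run
# on average. A quick check of that arithmetic (n_layers=6 is a
# hypothetical value, not from the source):
n = 6
p_run = [(n - i) / n for i in range(n)]      # [1.0, 5/6, 4/6, 3/6, 2/6, 1/6]
assert abs(sum(p_run) - (n + 1) / 2) < 1e-9  # Expected layers run: 3.5.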