Example #1
0
  def forward(self, xs):
    """Executes this conditional layer as part of a forward pass.

    Runs `self._cond` on its inputs from the stack, then uses `fastmath.cond`
    to execute exactly one of `self._true` / `self._false` on the remaining
    stack, threading the weights/state/rng at index 1 (true) or 2 (false)
    into whichever branch runs.

    Args:
      xs: Input activations — a single item or a tuple forming the data stack.

    Returns:
      The resulting stack, unwrapped to a single item if it has length one.
    """
    # TODO(jaszczur): modify; it's a copy from SkippingSerial
    self._validate_forward_inputs(xs)
    layers_state = self.state
    # Get 3 rngs, one for each layer.
    rngs = _split_rngs(self.rng, 3)

    # Prepare the stack and do some safety checks as in the parent class.
    stack = _make_tuple(xs)
    weights = self.weights
    # Exactly three sublayers are managed here: cond, true and false.
    if len(weights) != 3:
      raise ValueError('number of weights ({}) not equal to 3'
                       .format(len(weights)))
    if len(layers_state) != 3:
      raise ValueError('length of state ({}) not equal to 3'
                       .format(len(layers_state)))

    def true_func(t):
      # t = ((stack, stack), (true_w, false_w), (true_s, false_s),
      #      (true_rng, false_rng)); index [0] of each pair is the true half.
      outputs, new_true_state = self._true.pure_fn(
          t[0][0], t[1][0], t[2][0], t[3][0])
      # t[2][1] is old_false_state which is not changing if true is executed.
      return outputs, (new_true_state, t[2][1])

    def false_func(t):
      if self._identity_false_fun:
        # Memory optimization: we don't need pure_fn call.
        return t[0][1], t[2]
      outputs, new_false_state = self._false.pure_fn(
          t[0][1], t[1][1], t[2][1], t[3][1])
      # t[2][1] is old_true_state, which is not changing if false is executed.
      return outputs, (t[2][0], new_false_state)

    # Run the condition layer eagerly (outside fastmath.cond); use_cache=True
    # presumably shares cached weights with other calls — confirm in pure_fn.
    cond_inputs = _inputs_from_stack(self._cond, xs)
    cond_output, s = self._cond.pure_fn(cond_inputs, self.weights[0],
                                        self.state[0], rngs[0], use_cache=True)
    # Pop the cond layer's consumed inputs off the stack (it pushes nothing).
    stack = _outputs_onto_stack(self._cond, [], stack)
    self._cond.state = s

    outputs, both_states = fastmath.cond(
        cond_output,
        true_func,
        false_func,
        [(stack, stack),
         (self.weights[1], self.weights[2]),
         (self.state[1], self.state[2]),
         (rngs[1], rngs[2])]
    )
    # NOTE(review): this repeats the `self._cond` stack adjustment already
    # performed right after the cond pure_fn call above, so the cond inputs
    # appear to be popped twice. Looks like a leftover from the SkippingSerial
    # copy (see TODO at the top) — confirm against `_outputs_onto_stack`.
    stack = _outputs_onto_stack(self._cond, [], stack)

    # We don't know which (`true` or `false`) branch was run, but both of them
    # are adding (n_out) and removing (n_in) the same number of elements of the
    # stack (this was checked in __init__). _outputs_onto_stack just uses the
    # layer's n_in and n_out, so we can pass either `true` or `false` to it.
    # Note that `outputs` is the actual output of `true` or `false` branch,
    # whichever was run, and we add it to the stack in any case.
    stack = _outputs_onto_stack(self._true, outputs, stack)
    self._true.state = both_states[0]
    self._false.state = both_states[1]
    return _make_singleitem_or_original(stack)
Example #2
0
  def forward(self, inputs):
    """Normalizes observations using running mean/variance statistics.

    In 'collect' mode the running statistics are first updated from the last
    timestep of each batch element, but only until `self._sample_limit`
    samples have been absorbed; in every other mode the stored statistics are
    used unchanged.

    Args:
      inputs: Batch of observations, indexed as (batch_size, time, ...).

    Returns:
      Observations normalized by the (possibly updated) running statistics.
    """
    stats = self.state
    if self._mode == 'collect':
      # Statistics accumulate only while collecting data with the agent.
      # Each observation updates the estimate separately for simplicity;
      # during data collection the batch size is 1 anyway.
      for obs in inputs[:, -1]:  # last timestep of each batch element
        seen = running_mean_and_variance_get_count(stats)
        # Stop absorbing new samples once the sample limit is reached.
        stats = fastmath.cond(
            seen < self._sample_limit,
            true_operand=(obs, stats),
            true_fun=lambda args: running_mean_and_variance_update(*args),
            false_operand=None,
            false_fun=lambda _: stats,
        )

    mean = running_mean_and_variance_get_mean(stats)
    var = running_mean_and_variance_get_variance(stats)
    self.state = stats
    return (inputs - mean) / (var ** 0.5 + self._epsilon)
Example #3
0
  def forward(self, xs):
    """Runs the sublayers serially, stochastically skipping a suffix of them.

    In 'train' mode a scalar threshold `n_forward_layers` is sampled (with a
    warmup schedule driven by the step counter); each sublayer executes only
    if its index is below the threshold, otherwise its inputs pass through
    unchanged. In all other modes every sublayer runs.

    Args:
      xs: Input activations — a single item or a tuple forming the data stack.

    Returns:
      The stack after running the (possibly truncated) sequence of sublayers.
    """
    self._validate_forward_inputs(xs)
    # State is a pair: (training step counter, list of per-sublayer states).
    (step, layers_state) = self.state
    # Get N+1 rngs, N for running layers and one extra.
    rngs = _split_rngs(self.rng, self._n_layers + 1)
    # rng0 drives the layer-skipping decision; the rest go to the sublayers.
    rng0, rngs = rngs[0], rngs[1:]
    if not self.sublayers:  # No-op: leave args unchanged.
      self.state = (step + 1, layers_state)
      return xs

    # Prepare the stack and do some safety checks as in the parent class.
    stack = xs
    new_state = []
    n_layers = self._n_layers
    weights = self.weights
    if n_layers != 1 and len(weights) != n_layers:
      raise ValueError('number of weights ({}) not equal to number of layers '
                       '({})'.format(len(weights), n_layers))
    if n_layers != 1 and len(layers_state) != n_layers:
      raise ValueError('length of state ({}) not equal to number of layers '
                       '({})'.format(len(layers_state), n_layers))

    # TODO(chowdhery): try different strategies, also try running not all
    # layers backwards by using fastmath.stop_gradient where needed.

    # Calculate how many layers to run forward.
    if self._mode == 'train':
      # warmup goes from 1.0 at start to 0.0 at skipping_warmup_steps and after
      w_steps = float(self._skipping_warmup_steps)
      f_step = step.astype(jnp.float32)
      warmup = jnp.maximum(0.0, (w_steps - f_step) / w_steps)
      # low is the minimum number of layers to *not* skip, from n_layers to 0
      low = warmup * float(n_layers)
      # high should be so that (high - n_layers) / high = 1.0 - skip_fraction
      # because (high - n_layers) / high is the probability we're not skipping
      # (after warmup); so high - n_layers = high - high * skip_fraction
      high = float(n_layers) / self._skip_fraction
      # We want the same rng0 on all cores.
      # NOTE(review): psum over 'batch' assumes this runs under a pmap with
      # that axis name — confirm against the training harness.
      if fastmath.device_count() > 1:
        rng0 = fastmath.psum(rng0, 'batch')
      # Sample the (fractional) number of layers to run in [low, high).
      n_forward_layers = random.uniform(rng0, (), jnp.float32, low, high)
    else:
      # Evaluation/inference: run all layers.
      n_forward_layers = float(n_layers)
    # Run layers skipping after a certain number.
    cur_layer_idx = 0.0
    for layer, p, s, rng in zip(self.sublayers, weights, layers_state, rngs):
      inputs = _inputs_from_stack(layer, stack)
      # Capturing `layer` from the loop is safe here: fastmath.cond consumes
      # CondF within this same iteration, before `layer` is rebound.
      def CondF(t):
        return layer.pure_fn(t[0], t[1], t[2], t[3])  # pylint: disable=cell-var-from-loop
      def PassF(t):
        # Skip branch: return inputs and state unchanged.
        return t[0], t[2]
      outputs, s = fastmath.cond(
          fastmath.lt(cur_layer_idx, n_forward_layers),
          CondF,
          PassF,
          (inputs, p, s, rng)
      )
      stack = _outputs_onto_stack(layer, outputs, stack)
      new_state.append(s)
      cur_layer_idx += 1.0
    # Advance the step counter and store the per-sublayer states.
    self.state = (step + 1, new_state)
    return stack