def forward_with_state(self, inputs, weights=base.EMPTY_WEIGHTS,
                       state=base.EMPTY_STATE, **kwargs):
  """Runs `self._layer` step by step via `backend.scan`, threading a carry.

  The input stack is split into scanned inputs (everything except the last
  `self._n_carry` items) and an initial carry (the last `self._n_carry`
  items). At each scan step the wrapped layer receives the step's inputs
  followed by the current carry. The trailing `self._n_carry` items of its
  output become the next carry; the remaining items are the step's outputs.

  Args:
    inputs: Input stack (sliced and concatenated like a tuple); its trailing
      `self._n_carry` items are the initial carry.
    weights: Weights passed unchanged to the wrapped layer at every step.
    state: Initial state of the wrapped layer, threaded through the scan.
    **kwargs: Forwarded to `self._layer.forward_with_state`.

  Returns:
    A pair `(outputs, new_state)`: the scanned outputs followed by the final
    carry, and the wrapped layer's state after the last step.
  """
  n_carry = self._n_carry

  def split_carry(stack):
    """Splits `stack` into (leading items, trailing `n_carry` items)."""
    # n_carry == 0 needs its own branch: stack[:-0] is stack[:0] (empty) and
    # stack[-0:] is the whole stack, which would invert the split.
    if n_carry > 0:
      return stack[:-n_carry], stack[-n_carry:]
    return stack, ()

  def scannable_fn(x, carry_and_state):
    """One scan step: apply the layer to `x` + carry, split off new carry."""
    carry, step_state = carry_and_state
    layer_inputs = x + carry if n_carry > 0 else x
    res, new_step_state = self._layer.forward_with_state(
        layer_inputs, weights=weights, state=step_state, **kwargs)
    step_outputs, new_carry = split_carry(res)
    return step_outputs, (new_carry, new_step_state)

  # Split input stack into inputs and carry.
  xs, init_carry = split_carry(inputs)
  ys, (carry, new_state) = backend.scan(
      scannable_fn, xs, (init_carry, state), axis=self._axis)
  # Put outputs and carry back on the stack (nothing to append without carry).
  outputs = ys + carry if n_carry > 0 else ys
  return outputs, new_state
Exemplo n.º 2
0
  def forward_with_state(self, inputs, weights=base.EMPTY_WEIGHTS,
                         state=base.EMPTY_STATE, **kwargs):
    """Runs `self.sublayer` step by step via `backend.scan`, threading a carry.

    The input stack is split into scanned inputs (everything except the last
    `self._n_carry` items) and an initial carry (the last `self._n_carry`
    items). At each scan step the sublayer receives the step's inputs
    followed by the current carry. The trailing `self._n_carry` items of its
    output become the next carry; the remaining items are the step's outputs.

    Args:
      inputs: Input stack (sliced and concatenated like a tuple); its
        trailing `self._n_carry` items are the initial carry.
      weights: Weights passed unchanged to the sublayer at every step.
      state: Initial sublayer state, threaded through the scan.
      **kwargs: Forwarded to `self.sublayer.forward_with_state`.

    Returns:
      A pair `(outputs, new_state)`: the scanned outputs followed by the
      final carry, and the sublayer state after the last step.
    """
    n_carry = self._n_carry

    def split_carry(stack):
      """Splits `stack` into (leading items, trailing `n_carry` items)."""
      # n_carry == 0 needs its own branch: stack[:-0] is stack[:0] (empty)
      # and stack[-0:] is the whole stack, which would invert the split.
      if n_carry > 0:
        return stack[:-n_carry], stack[-n_carry:]
      return stack, ()

    def scannable_fn(x, carry_and_state):
      """One scan step: apply the sublayer to `x` + carry, split new carry."""
      carry, step_state = carry_and_state
      sublayer_inputs = x + carry if n_carry > 0 else x
      res, new_step_state = self.sublayer.forward_with_state(
          sublayer_inputs, weights=weights, state=step_state, **kwargs)
      step_outputs, new_carry = split_carry(res)
      return step_outputs, (new_carry, new_step_state)

    # Split input stack into inputs and carry.
    xs, init_carry = split_carry(inputs)
    ys, (carry, new_state) = backend.scan(scannable_fn, xs,
                                          (init_carry, state),
                                          axis=self._axis)
    # Put outputs and carry back on the stack (nothing to append w/o carry).
    outputs = ys + carry if n_carry > 0 else ys
    return outputs, new_state