def forward_with_state(self, inputs, weights, state, rng):
    """Scan the sublayer over `inputs`, threading carry values and state.

    The last `self._n_carry` entries of `inputs` form the initial carry; the
    remaining entries are the per-step inputs handed to `math.scan`. When
    `self._n_carry` is not positive, all of `inputs` is scanned and the carry
    is an empty list.

    Args:
      inputs: The input stack. A list is converted to a tuple first.
      weights: Passed unchanged to every call of the sublayer.
      state: Initial sublayer state; threaded through the scan.
      rng: Passed unchanged to every call of the sublayer.

    Returns:
      A pair `(outputs, new_state)`. With a positive `self._n_carry`,
      `outputs` is the scanned outputs followed by the final carry;
      otherwise it is just the scanned outputs.
    """
    # A list becomes a tuple so that inputs structure matches outputs.
    if isinstance(inputs, list):
        inputs = tuple(inputs)

    n_carry = self._n_carry
    has_carry = n_carry > 0

    def step(x, carry_and_state):
        """Run the sublayer once; return (step outputs, (carry, state))."""
        carry, step_state = carry_and_state
        if has_carry:
            sublayer_inputs = x + carry
        else:
            sublayer_inputs = x
        outputs, next_state = self.sublayer.forward_with_state(
            sublayer_inputs, weights, step_state, rng)
        if not has_carry:
            return (outputs, ([], next_state))
        # The trailing n_carry outputs become the carry for the next step.
        return (outputs[:-n_carry], (outputs[-n_carry:], next_state))

    if has_carry:
        # Split input stack into inputs and carry.
        xs = inputs[:-n_carry]
        init = (inputs[-n_carry:], state)
    else:
        xs = inputs
        init = ([], state)

    ys, (carry, new_state) = math.scan(
        step, xs, init, axis=self._axis, remat=self._remat)

    # Put outputs and carry back on stack.
    if has_carry:
        return ys + carry, new_state
    return ys, new_state
def forward_with_state(self, inputs, weights=base.EMPTY_WEIGHTS,
                       state=base.EMPTY_STATE, **kwargs):
    """Scan the sublayer over `inputs`, threading carry values and state.

    The last `self._n_carry` entries of `inputs` form the initial carry; the
    remaining entries are the per-step inputs handed to `math.scan`. When
    `self._n_carry` is not positive, all of `inputs` is scanned and the carry
    is an empty list.

    Args:
      inputs: The input stack. A list is converted to a tuple first.
      weights: Passed as `weights=` to every call of the sublayer.
      state: Initial sublayer state; threaded through the scan.
      **kwargs: Forwarded unchanged to every call of the sublayer.

    Returns:
      A pair `(outputs, new_state)`. With a positive `self._n_carry`,
      `outputs` is the scanned outputs followed by the final carry;
      otherwise it is just the scanned outputs.
    """
    # Normalize a list to a tuple so that the inputs structure matches the
    # outputs. NOTE(review): this presumes `math.scan` needs the initial
    # carry and the carry returned by each step to share one container
    # type, and that the sublayer returns tuples -- confirm.
    if isinstance(inputs, list):
        inputs = tuple(inputs)

    n_carry = self._n_carry

    def scannable_fn(x, carry_and_state):  # pylint: disable=invalid-name
        """Run the sublayer once; return (step outputs, (carry, state))."""
        # Named `step_state` to avoid shadowing the outer `state` argument.
        carry, step_state = carry_and_state
        x_and_carry = x + carry if n_carry > 0 else x
        res, new_state = self.sublayer.forward_with_state(
            x_and_carry, weights=weights, state=step_state, **kwargs)
        if n_carry > 0:
            # The trailing n_carry outputs become the next step's carry.
            return (res[:-n_carry], (res[-n_carry:], new_state))
        return (res, ([], new_state))

    if n_carry > 0:
        # Split input stack into inputs and carry.
        xs = inputs[:-n_carry]
        init = (inputs[-n_carry:], state)
    else:
        xs, init = inputs, ([], state)

    ys, (carry, new_state) = math.scan(
        scannable_fn, xs, init, axis=self._axis, remat=self._remat)

    # Put outputs and carry back on stack.
    res = ys + carry if n_carry > 0 else ys
    return res, new_state