コード例 #1
0
    def forward_with_state(self, inputs, weights, state, rng):
        """Applies the sublayer step by step via `math.scan`.

        When `self._n_carry` is positive, the last `n_carry` elements of
        `inputs` form the initial carry and the remaining elements are the
        values scanned over. On every step the carry is appended to the
        current slice before calling the sublayer, and the last `n_carry`
        elements of the sublayer's result become the next carry.

        Args:
          inputs: Inputs to the layer; a list is converted to a tuple first.
          weights: Weights handed unchanged to the sublayer on every step.
          state: Initial sublayer state; each step's returned state is fed to
            the next step.
          rng: Random number generator value handed unchanged to the sublayer
            on every step.

        Returns:
          A pair `(outputs, new_state)`. `outputs` is the scanned result with
          the final carry appended when `n_carry > 0`; `new_state` is the
          state returned by the scan.
        """
        if isinstance(inputs, list):
            # A tuple keeps the structure of the inputs matching the outputs.
            inputs = tuple(inputs)
        n_carry = self._n_carry
        has_carry = n_carry > 0

        def scan_step(x, carry_and_state):  # pylint: disable=invalid-name
            carry, step_state = carry_and_state
            sublayer_inputs = x + carry if has_carry else x
            outputs, next_state = self.sublayer.forward_with_state(
                sublayer_inputs, weights, step_state, rng)
            if not has_carry:
                return (outputs, ([], next_state))
            # The trailing n_carry outputs are carried over to the next step.
            return (outputs[:-n_carry], (outputs[-n_carry:], next_state))

        # Separate the input stack into scanned values and the initial carry.
        if has_carry:
            xs = inputs[:-n_carry]
            init = (inputs[-n_carry:], state)
        else:
            xs = inputs
            init = ([], state)

        ys, (final_carry, new_state) = math.scan(scan_step,
                                                 xs,
                                                 init,
                                                 axis=self._axis,
                                                 remat=self._remat)

        # Put the outputs and the final carry back on the stack.
        if has_carry:
            return ys + final_carry, new_state
        return ys, new_state
コード例 #2
0
ファイル: combinators.py プロジェクト: yasyaindra/trax
    def forward_with_state(self,
                           inputs,
                           weights=base.EMPTY_WEIGHTS,
                           state=base.EMPTY_STATE,
                           **kwargs):
        """Applies the sublayer step by step via `math.scan`.

        When `self._n_carry` is positive, the last `n_carry` elements of
        `inputs` form the initial carry and the remaining elements are the
        values scanned over. On every step the carry is appended to the
        current slice before calling the sublayer, and the last `n_carry`
        elements of the sublayer's result become the next carry.

        Args:
          inputs: Inputs to the layer; a list is converted to a tuple first.
          weights: Weights handed unchanged to the sublayer on every step.
          state: Initial sublayer state; each step's returned state is fed to
            the next step.
          **kwargs: Extra keyword arguments forwarded unchanged to the
            sublayer's `forward_with_state` on every step.

        Returns:
          A pair `(outputs, new_state)`. `outputs` is the scanned result with
          the final carry appended when `n_carry > 0`; `new_state` is the
          state returned by the scan.
        """
        if isinstance(inputs, list):
            # Slices of a list are lists, so without this the initial carry
            # would be a list while later carries are slices of the sublayer's
            # outputs. Normalizing to a tuple keeps the input structure stable
            # regardless of how the caller packed the inputs.
            # NOTE(review): assumes the sublayer returns tuples for multiple
            # outputs -- confirm against the sublayer implementation.
            inputs = tuple(inputs)
        n_carry = self._n_carry

        def scannable_fn(x, carry_and_state):  # pylint: disable=invalid-name
            carry, state = carry_and_state
            x_and_carry = x + carry if n_carry > 0 else x
            res, new_state = self.sublayer.forward_with_state(x_and_carry,
                                                              weights=weights,
                                                              state=state,
                                                              **kwargs)
            if n_carry > 0:
                # The trailing n_carry results are carried to the next step.
                return (res[:-n_carry], (res[-n_carry:], new_state))
            else:
                return (res, ([], new_state))

        if n_carry > 0:
            xs = inputs[:-n_carry]  # Split input stack into inputs and carry.
            init = (inputs[-n_carry:], state)
        else:
            xs, init = inputs, ([], state)
        ys, (carry, new_state) = math.scan(scannable_fn,
                                           xs,
                                           init,
                                           axis=self._axis,
                                           remat=self._remat)
        res = ys + carry if n_carry > 0 else ys
        return res, new_state  # Put outputs and carry back on stack.