コード例 #1
0
    def forward_with_state(self,
                           inputs,
                           weights=base.EMPTY_WEIGHTS,
                           state=base.EMPTY_STATE,
                           rng=None,
                           **kwargs):
        """Applies dot-product attention to a ``(q, k, v)`` triple.

        In ``'train'`` / ``'eval'`` mode a lower-triangular boolean mask of
        shape ``(1, L, L)`` is built, with ``L = q.shape[-2]``, so query
        position ``i`` is paired only with key positions ``j <= i``
        (presumably ``True`` means "may attend" -- TODO confirm against
        ``DotProductAttention``). In ``'predict'`` mode, ``k``, ``v`` and the
        mask are instead read from the state returned by
        ``_fast_inference_update_state``.

        Args:
          inputs: A 3-tuple ``(q, k, v)``.
          weights: Unused; discarded immediately.
          state: Layer state. It is only read and replaced in ``'predict'``
            mode; otherwise it is returned unchanged.
          rng: Forwarded unchanged to ``DotProductAttention``.
          **kwargs: Accepted for interface compatibility and ignored.

        Returns:
          A pair ``(res, state)``. ``res`` is the result of
          ``DotProductAttention`` and ``state`` is the (possibly updated)
          layer state.

        Raises:
          ValueError: If ``self._mode`` is not one of ``'train'``, ``'eval'``
            or ``'predict'``.
        """
        del weights, kwargs
        q, k, v = inputs
        if self._mode in ('train', 'eval'):
            mask_size = q.shape[-2]
            # Not all backends define np.tril. However, using onp.tril is
            # inefficient in that it creates a large global constant.
            # TODO(kitaev): try to find an alternative that works across all
            # backends.
            if math.backend_name() == 'jax':
                mask = np.tril(np.ones((1, mask_size, mask_size),
                                       dtype=onp.bool_),
                               k=0)
            else:
                mask = onp.tril(onp.ones((1, mask_size, mask_size),
                                         dtype=onp.bool_),
                                k=0)
        elif self._mode == 'predict':
            # The helper returns a 4-tuple whose first three entries replace
            # k, v and the mask; the fourth entry is not needed here.
            state = _fast_inference_update_state(inputs, state)
            (k, v, mask, _) = state
        else:
            # An ``assert`` would be stripped under ``python -O``, letting an
            # unknown mode silently fall into the predict path.
            raise ValueError(
                "Unsupported mode {!r}; expected 'train', 'eval' or "
                "'predict'.".format(self._mode))

        res = DotProductAttention(q,
                                  k,
                                  v,
                                  mask,
                                  dropout=self._dropout,
                                  mode=self._mode,
                                  rng=rng)
        return res, state
コード例 #2
0
    def forward_with_state(self, inputs, weights, state, rng):
        """Runs dot-product attention over ``inputs = (q, k, v)``.

        In any mode other than ``'predict'``, a ``(1, L, L)`` lower-triangular
        boolean mask is built with ``L = q.shape[-2]``. In ``'predict'`` mode,
        the keys, values and mask are taken from the state produced by
        ``_fast_inference_update_state``.

        Args:
          inputs: A 3-tuple ``(q, k, v)``.
          weights: Ignored; discarded immediately.
          state: Layer state. It is replaced only in ``'predict'`` mode and
            otherwise returned as-is.
          rng: Forwarded unchanged to ``DotProductAttention``.

        Returns:
          ``(output, state)``, where ``output`` is the result of
          ``DotProductAttention``.
        """
        del weights
        query, key, value = inputs

        if self._mode != 'predict':
            length = query.shape[-2]
            # Not every backend defines jnp.tril, so non-JAX backends fall
            # back to np.tril, which is inefficient because it creates a
            # large global constant.
            # TODO(kitaev): find an alternative that works across all backends.
            xnp = jnp if math.backend_name() == 'jax' else np
            mask = xnp.tril(xnp.ones((1, length, length), dtype=np.bool_), k=0)
        else:
            state = _fast_inference_update_state(inputs, state)
            key, value, mask, _ = state

        output = DotProductAttention(query, key, value, mask,
                                     dropout=self._dropout,
                                     mode=self._mode,
                                     rng=rng)
        return output, state