예제 #1
0
 def call(self, inputs, params=(), rng=None, **kwargs):
   """Apply dot-product attention restricted by a lower-triangular mask.

   Args:
     inputs: A 3-tuple ``(q, k, v)`` of query, key and value arrays.
     params: Not used by this method (discarded immediately).
     rng: Random key forwarded to ``tl.DotProductAttention``.
     **kwargs: Ignored.

   Returns:
     Whatever ``tl.DotProductAttention`` returns for the masked inputs.
   """
   del params  # Not used by this method.
   queries, keys, values = inputs
   # The mask side length is taken from the second-to-last axis of the
   # queries; NOTE(review): assumes keys share that length — confirm.
   seq_len = queries.shape[-2]
   # Boolean mask whose entries (i, j) are True exactly when j <= i. The
   # leading axis of size 1 is presumably broadcast by the attention op.
   causal_mask = np.tril(
       np.ones((1, seq_len, seq_len), dtype=onp.bool_), k=0)
   return tl.DotProductAttention(
       queries, keys, values, causal_mask,
       dropout=self._dropout, mode=self._mode, rng=rng)
예제 #2
0
def QueryPositionKV(x, keys=None, values=None, binary=False, **unused_kwargs):
    """Query a key/value table with a position vector via dot-product attention.

    Args:
      x: Query array. Returned unchanged when ``keys`` is None.
      keys: Table keys, converted with ``np.array``. If None, the lookup is
        skipped and ``x`` is returned as-is.
      values: Table values, converted with ``np.array``. Required whenever
        ``keys`` is given.
      binary: If True, the query is ``x`` concatenated with itself along the
        last axis before attending.
      **unused_kwargs: Ignored.

    Returns:
      ``x`` when ``keys`` is None; otherwise the result of
      ``tl.DotProductAttention`` over the table with no mask and no dropout.

    Raises:
      ValueError: If ``keys`` is given but ``values`` is None.
    """
    if keys is None:
        return x
    if values is None:
        # Fail early with a clear message instead of attending over
        # np.array(None), which breaks obscurely inside the attention op.
        raise ValueError('QueryPositionKV: `values` must be given with `keys`.')
    k = np.array(keys)
    v = np.array(values)
    q = x
    if binary:
        # NOTE(review): presumably binary keys are twice as wide as x, so the
        # query is tiled to match — confirm against the key-table builder.
        q = np.concatenate([x, x], axis=-1)
    # "No dropout" is passed as the number 0.0 rather than None so that any
    # numeric check on the rate inside DotProductAttention (such as a
    # `dropout >= 1.0` guard) cannot raise TypeError. Keywords make the
    # meaning of each argument explicit.
    return tl.DotProductAttention(
        q, k, v, None, dropout=0.0, mode=None, rng=None)