Example #1
0
 def to_qk_with_pos(self, q, k):
     """Return attention logits made of a content term plus a positional term.

     The content term contracts ``q + pos_x_b`` with ``k``.  The positional
     term contracts ``q + pos_b`` with ``pos_tim`` projected through
     ``pos_w`` and split into heads; it goes through ``self.shift`` before
     being added to the content term.

     Args:
         q: Queries; the einsum spec treats them as ``(b, n, i, k)``.
         k: Keys; the einsum spec treats them as ``(b, n, j, k)``.

     Returns:
         The summed logits, laid out as ``(b, n, i, j)`` by the einsum spec.
     """
     spec = 'bnik,bnjk->bnij'
     # Content term: bias-shifted queries against the keys.
     content_q = q + self.pos_x_b[:, None, :, None]
     content = tf.einsum(spec, content_q, k)
     # Positional "keys": project the timing table, split it into heads and
     # prepend an axis so it lines up with the batched queries.
     pos_k = tf.einsum('ih,hk->ik', self.pos_tim, self.pos_w)
     pos_k = self.split_heads(pos_k)[None, ]
     # Positional term uses its own query bias.
     pos_q = q + self.pos_b[:, None, :, None]
     positional = tf.einsum(spec, pos_q, pos_k)
     return content + self.shift(positional)
Example #2
0
 def lookup(self, x, i):
     """Embed ``x`` with table ``i`` and, if present, its adapter matrix.

     Args:
         x: Ids to embed (assumed to be integer ids - TODO confirm).
         i: Index selecting the entry of ``table_ws`` / ``adapt_ws``.

     Returns:
         The embeddings, multiplied by ``adapt_ws[i]`` when that entry is
         not ``None``.
     """
     table = self.table_ws[i]
     if not self.one_hot:
         emb = tf.embedding_lookup(table, x)
     else:
         # One-hot path: express the lookup as a matrix product.
         hot = tf.one_hot(x, tf.shape(table)[0], axis=-1)
         emb = tf.einsum('np,in->ip', table, hot)
     adapter = self.adapt_ws[i]
     if adapter is None:
         return emb
     return tf.einsum('ip,ph->ih', emb, adapter)
Example #3
0
 def logits(self, x, i=None):
     """Project ``x`` onto the rows of table ``i`` to get output logits.

     Args:
         x: Input activations; the einsum spec treats them as ``(b, i, h)``.
         i: Table index.  ``None`` falls back to table 0 but, unlike an
             explicit ``0``, does NOT append the cluster weights and bias.

     Returns:
         ``x`` (optionally passed through ``adapt_ws``) multiplied by the
         selected table, plus the matching bias.
     """
     idx = i or 0
     adapter = self.adapt_ws[idx]
     if adapter is None:
         hidden = x
     else:
         hidden = tf.einsum('bih,ph->bip', x, adapter)
     weights = self.table_ws[idx]
     bias = self.table_bs[idx]
     # Only an explicit index of 0 extends the table with the cluster
     # parameters; ``i is None`` skips this on purpose of the comparison.
     if i == 0:
         weights = tf.concat([weights, self.clust_w], 0)
         bias = tf.concat([bias, self.clust_b], 0)
     return tf.einsum('bie,ne->bin', hidden, weights) + bias
Example #4
0
 def call(self, inputs, mask=None):
     """Add a token-type embedding to ``x``.

     Args:
         inputs: A pair ``(x, typ)``; ``typ`` holds the type ids that are
             one-hot encoded with depth ``cfg.tok_types``.
         mask: Optional mask; when given, ``typ`` is multiplied by it
             (cast to ``typ.dtype``) before encoding.

     Returns:
         ``x`` plus the one-hot types projected through ``typ_w``.
     """
     x, typ = inputs
     type_ids = typ
     if mask is not None:
         # Multiplying by the mask sends masked-out positions to type 0.
         type_ids *= tf.cast(mask, typ.dtype)
     hot = tf.one_hot(type_ids, self.cfg.tok_types)
     type_emb = tf.einsum('bie,eh->bih', hot, self.typ_w)
     return x + type_emb
Example #5
0
def pos_timing_2(dim, end, p_max, p_min, p_start):
    """Build a sin/cos timing table with geometrically spaced scales.

    Positions are ``p_start, p_start + 1, ...`` (``end`` of them).  Each is
    multiplied by ``dim // 2`` scales of the form
    ``p_min * exp(-d * log(p_max / p_min) / max(dim // 2 - 1, 1))`` and the
    sines and cosines of the products are concatenated on the last axis.

    Args:
        dim: Width of the returned table; must be even.
        end: Number of positions.
        p_max: Upper value of the scale ratio.
        p_min: Lower value of the scale ratio, also the leading scale.
        p_start: Offset added to every position.

    Returns:
        A table with ``end`` rows and ``dim`` columns (sines first, then
        cosines).
    """
    positions = tf.range(end, dtype=tf.floatx()) + p_start
    assert dim % 2 == 0
    half = dim // 2
    # max(..., 1) avoids a division by zero when there is a single scale.
    log_step = np.log(p_max / p_min) / max(half - 1, 1)
    scales = tf.exp(tf.range(half, dtype=tf.floatx()) * -log_step) * p_min
    angles = tf.einsum('i,d->id', positions, scales)
    return tf.concat([tf.sin(angles), tf.cos(angles)], axis=-1)
Example #6
0
 def call(self, inputs, mask=None):
     """Run multi-head attention over ``x``, optionally with a context prefix.

     Args:
         inputs: A sequence whose first item is ``x`` and whose optional
             second item is a context that is concatenated in front of
             ``x`` along axis 1.
         mask: Optional mask forwarded to ``to_scores``.

     Returns:
         The result of ``self.post([x, attended])`` where ``attended`` is
         the attention output projected through ``out_w``.
     """
     x, ctx = inputs[0], inputs[1] if len(inputs) > 1 else None
     y = x if ctx is None else tf.concat([ctx, x], axis=1)
     y = self.pre([y, y])
     if self.v_w is None:
         # A single shared projection serves as query/key and value.
         qk_proj = v = tf.einsum('bih,hk->bik', y, self.qkv_w)
     else:
         # Separate projections.  Both must read the pre-processed input:
         # the previous code fed the not-yet-assigned ``v`` into the value
         # projection, which raised UnboundLocalError on this branch.
         qk_proj = tf.einsum('bih,hk->bik', y, self.qk_w)
         v = tf.einsum('bih,hv->biv', y, self.v_w)
     xlen = x.shape[1]  # tf.int_shape(x)[1]
     # Queries come only from the trailing ``xlen`` positions (i.e. ``x``);
     # keys cover the context as well.
     q = self.split_heads(qk_proj[:, -xlen:, :])
     k = self.split_heads(qk_proj)
     if self.pos_tim is None:
         qk = tf.einsum('bnik,bnjk->bnij', q, k)
     else:
         qk = self.to_qk_with_pos(q, k)
     v = self.split_heads(v)
     y = self.to_scores(qk, mask, v)
     y = self.join_heads(y)
     y = tf.einsum('biv,vh->bih', y, self.out_w)
     y = self.post([x, y])
     return y
Example #7
0
 def to_scores(self, qk, mask, v):
     """Turn raw attention logits into weighted values.

     Args:
         qk: Attention logits; the einsum spec treats them as
             ``(b, n, i, j)``.
         mask: Optional boolean mask.  Its logical negation is scaled by
             ``utils.big_neg()`` and added to the logits as a bias.
         v: Values; the einsum spec treats them as ``(b, n, j, v)``.

     Returns:
         Softmax of the scaled, biased logits, after dropout, contracted
         with ``v``.
     """
     bias = 0
     if mask is not None:
         blocked = tf.cast(tf.logical_not(mask), tf.floatx())
         bias = blocked * utils.big_neg()
         # NOTE(review): proxim_b is only applied when a mask is supplied.
         if self.proxim_b is not None:
             bias += self.proxim_b
         bias = bias[:, None, :, None]
     probs = tf.softmax(qk * self.scale + bias)
     cfg = self.cfg
     # Fall back to the hidden dropout rate when no attention rate is set.
     rate = cfg.drop_attn or cfg.drop_hidden
     probs = self.drop(probs, rate)
     return tf.einsum('bnij,bnjv->bniv', probs, v)
Example #8
0
def pos_timing(dim, end):
    """Build a sin/cos timing table over positions counted down to zero.

    Positions run ``end - 1, end - 2, ..., 0``.  Each is multiplied by the
    inverse frequencies ``1 / 10000 ** (f / dim)`` for ``f = 0, 2, 4, ...``
    below ``dim``; sines and cosines are concatenated on the last axis.

    Args:
        dim: Upper bound of the even frequency indices, also the divisor
            in the exponent.
        end: Number of positions.

    Returns:
        A table with one row per position (sines first, then cosines).
    """
    positions = tf.range(end - 1, -1, -1, dtype=tf.floatx())
    even_idx = tf.range(0, dim, 2.0, dtype=tf.floatx())
    inv_freq = 1 / (10000**(even_idx / dim))
    angles = tf.einsum('i,d->id', positions, inv_freq)
    return tf.concat([tf.sin(angles), tf.cos(angles)], axis=-1)