def to_qk_with_pos(self, q, k):
    """Attention scores with relative-position terms.

    Combines a content term (queries + content bias vs. keys) with a
    position term (queries + position bias vs. projected timing signal),
    shifting the latter into relative-offset alignment.
    """
    # Content-content term, with the learned content bias added to q.
    content_bias = self.pos_x_b[:, None, :, None]
    scores = tf.einsum('bnik,bnjk->bnij', q + content_bias, k)
    # Project the positional timing signal and split it into heads
    # (leading None adds the broadcastable batch axis).
    pos = tf.einsum('ih,hk->ik', self.pos_tim, self.pos_w)
    pos = self.split_heads(pos)[None, ]
    # Content-position term, with the learned position bias added to q;
    # self.shift realigns absolute offsets to relative ones.
    pos_bias = self.pos_b[:, None, :, None]
    pos_scores = tf.einsum('bnik,bnjk->bnij', q + pos_bias, pos)
    scores += self.shift(pos_scores)
    return scores
def lookup(self, x, i):
    """Embed ids `x` using table `i`, optionally applying its adapter.

    When `one_hot` is set, the lookup is performed as a dense one-hot
    matmul instead of a gather.
    """
    table = self.table_ws[i]
    if self.one_hot:
        # Dense one-hot matmul path.
        onehot = tf.one_hot(x, tf.shape(table)[0], axis=-1)
        y = tf.einsum('np,in->ip', table, onehot)
    else:
        # NOTE(review): standard TF exposes this as tf.nn.embedding_lookup;
        # presumably `tf` is a project shim providing it — confirm.
        y = tf.embedding_lookup(table, x)
    adapter = self.adapt_ws[i]
    if adapter is not None:
        # Project from the table's embedding width to the model width.
        y = tf.einsum('ip,ph->ih', y, adapter)
    return y
def logits(self, x, i=None):
    """Vocabulary logits for cluster `i` of an adaptive softmax.

    NOTE(review): i=None and i=0 both select tables[0], but only i == 0
    appends the cluster rows — presumably the None/0 distinction is
    deliberate (shared head weights vs. head-with-clusters); confirm.
    """
    adapter = self.adapt_ws[i or 0]
    y = x if adapter is None else tf.einsum('bih,ph->bip', x, adapter)
    table = self.table_ws[i or 0]
    bias = self.table_bs[i or 0]
    if i == 0:
        # Head cluster: extend with the per-cluster routing rows.
        table = tf.concat([table, self.clust_w], 0)
        bias = tf.concat([bias, self.clust_b], 0)
    return tf.einsum('bie,ne->bin', y, table) + bias
def call(self, inputs, mask=None):
    """Add token-type embeddings to `x`.

    Masked-out positions have their type multiplied to 0, so they pick
    up the embedding of type index 0.
    """
    x, typ = inputs
    types = typ if mask is None else typ * tf.cast(mask, typ.dtype)
    onehot = tf.one_hot(types, self.cfg.tok_types)
    return x + tf.einsum('bie,eh->bih', onehot, self.typ_w)
def pos_timing_2(dim, end, p_max, p_min, p_start):
    """Sin/cos positional timing signal of width `dim` for `end` steps.

    Positions start at `p_start`; the `dim // 2` frequencies follow a
    geometric progression anchored at `p_min`, with the decay rate set
    by the p_max/p_min ratio.
    """
    positions = tf.range(end, dtype=tf.floatx()) + p_start
    assert dim % 2 == 0
    half = dim // 2
    # Geometric frequency schedule; max(half - 1, 1) guards half == 1.
    step = np.log(p_max / p_min) / max(half - 1, 1)
    freqs = tf.exp(tf.range(half, dtype=tf.floatx()) * -step) * p_min
    angles = tf.einsum('i,d->id', positions, freqs)
    return tf.concat([tf.sin(angles), tf.cos(angles)], axis=-1)
def call(self, inputs, mask=None):
    """Multi-head attention over `x`, with an optional memory/context
    prefix prepended to the keys and values.

    inputs: [x] or [x, ctx]; queries come only from x, keys/values from
    ctx+x. Uses relative-position scoring when `pos_tim` is set.
    """
    x, ctx = inputs[0], inputs[1] if len(inputs) > 1 else None
    y = x if ctx is None else tf.concat([ctx, x], axis=1)
    y = self.pre([y, y])
    if self.v_w is None:
        # Fused projection: q/k and v share qkv_w.
        y = v = tf.einsum('bih,hk->bik', y, self.qkv_w)
    else:
        # BUG FIX: the original computed v from itself before any
        # assignment (UnboundLocalError). Project the values from the
        # normalized input BEFORE y is overwritten by the q/k projection.
        v = tf.einsum('bih,hv->biv', y, self.v_w)
        y = tf.einsum('bih,hk->bik', y, self.qk_w)
    xlen = x.shape[1]  # tf.int_shape(x)[1]
    # Queries only attend from the (trailing) x positions; keys span ctx+x.
    q = self.split_heads(y[:, -xlen:, :])
    k = self.split_heads(y)
    if self.pos_tim is None:
        qk = tf.einsum('bnik,bnjk->bnij', q, k)
    else:
        qk = self.to_qk_with_pos(q, k)
    v = self.split_heads(v)
    y = self.to_scores(qk, mask, v)
    y = self.join_heads(y)
    y = tf.einsum('biv,vh->bih', y, self.out_w)
    y = self.post([x, y])
    return y
def to_scores(self, qk, mask, v):
    """Turn raw scores into attended values: bias, softmax, drop, weight.

    Masked positions receive a large negative bias before the softmax;
    an optional proximity bias is folded in alongside it.
    """
    bias = 0
    if mask is not None:
        keep_out = tf.cast(tf.logical_not(mask), tf.floatx())
        bias = keep_out * utils.big_neg()
        if self.proxim_b is not None:
            bias += self.proxim_b
        # NOTE(review): this broadcasts the bias over the last (key) axis,
        # i.e. it masks along the query axis — confirm the intended mask
        # orientation ([:, None, None, :] would mask keys instead).
        bias = bias[:, None, :, None]
    probs = tf.softmax(qk * self.scale + bias)
    cfg = self.cfg
    probs = self.drop(probs, cfg.drop_attn or cfg.drop_hidden)
    return tf.einsum('bnij,bnjv->bniv', probs, v)
def pos_timing(dim, end):
    """Classic sin/cos timing signal of width `dim`, positions reversed.

    Positions run end-1 .. 0 with the standard 10000^(2k/dim) inverse
    frequencies.
    """
    positions = tf.range(end - 1, -1, -1, dtype=tf.floatx())
    exponents = tf.range(0, dim, 2.0, dtype=tf.floatx()) / dim
    inv_freq = 1 / (10000 ** exponents)
    angles = tf.einsum('i,d->id', positions, inv_freq)
    return tf.concat([tf.sin(angles), tf.cos(angles)], axis=-1)