def not_done(self, i):
    # lowest score among the still-active beams of each batch entry
    y = self.score * tf.cast(self.flags, tf.floatx())
    y = tf.reduce_min(y, axis=1)
    # batch entries with no active beam left get a large negative score
    fs = tf.reduce_any(self.flags, axis=1)
    old = y + (1. - tf.cast(fs, tf.floatx())) * utils.big_neg()
    n = tf.int_shape(self.tgt)[-1]
    # best score still achievable by a hypothesis of full length n
    new = self.logp[:, 0] / self.penalty(n)
    done = tf.reduce_all(tf.greater(old, new))
    return tf.logical_and(tf.less(i, n), tf.logical_not(done))
def pos_timing_2(dim, end, p_max, p_min, p_start):
    # positions, shifted by p_start
    t = tf.range(end, dtype=tf.floatx()) + p_start
    assert dim % 2 == 0
    n = dim // 2
    # geometric sequence of n frequencies, starting at p_min and
    # shrinking by a total factor of p_min / p_max
    f = np.log(p_max / p_min) / max(n - 1, 1)
    f = tf.range(n, dtype=tf.floatx()) * -f
    f = tf.exp(f) * p_min
    # outer product of positions and frequencies
    t = tf.einsum('i,d->id', t, f)
    return tf.concat([tf.sin(t), tf.cos(t)], axis=-1)
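# A quick shape check for pos_timing_2; the argument values here are
# illustrative assumptions, not taken from any config in this code.
pe = pos_timing_2(dim=8, end=4, p_max=1e4, p_min=1., p_start=0)
assert pe.shape == (4, 8)  # sines in columns 0..3, cosines in 4..7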
def call(self, inputs):
    cfg = self.cfg
    x, tgt = inputs
    if cfg.brackets:
        # per-position losses, accumulated across vocabulary brackets
        y = tf.zeros_like(tgt, dtype=tf.floatx())
        bs = cfg.brackets + [cfg.num_toks]
        b = 0
        for i, e in enumerate(bs):
            # targets falling into the current bracket [b, e),
            # skipping the padding token 0 in the first bracket
            msk = (tgt >= (b or 1)) & (tgt < e)
            mt = tf.boolean_mask(tgt, msk) - b
            gi = tf.stack([tf.range(tf.shape(mt)[0]), mt], axis=1)
            if i == 0:
                logp = tf.log_softmax(self.logits(x, i))
                mp = tf.boolean_mask(logp, msk)
                u = tf.gather_nd(mp, gi)
            else:
                # log-probability of the bracket itself...
                mp = tf.boolean_mask(logp, msk)
                u = mp[:, bs[i - 1]]
                # ...plus the log-probability within the bracket
                mc = tf.boolean_mask(x, msk)[None]
                mp = tf.log_softmax(self.logits(mc, i))
                mp = tf.squeeze(mp, 0)
                u += tf.gather_nd(mp, gi)
            # add the negative log-likelihoods at the masked positions
            y = tf.tensor_scatter_nd_add(y, tf.where(msk), -u)
            b = e
    else:
        y = self.logits(x)
        # f = tf.SparseCategoricalAccuracy
        # self.add_metric(f(name='acc')(tgt, y))
        f = tf.sparse_softmax_cross_entropy_with_logits
        loss = f(labels=tgt, logits=y)
        # self.add_loss(lambda: tf.reduce_mean(loss))
    return y
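# A self-contained numeric sketch of the factorization the brackets
# branch implements: the log-probability of a tail token is the head's
# log-probability of its bracket plus the in-bracket log-probability.
# The logits below are illustrative only; tf.log_softmax is the same
# alias used by the code above.
head = tf.log_softmax(tf.constant([[2., 1., 0., -1., .5]]))  # 4 tokens + 1 bracket slot
tail = tf.log_softmax(tf.constant([[.3, -.2, 1.1]]))  # 3 tokens inside the bracket
u = head[0, 4] + tail[0, 1]  # log p(bracket) + log p(token | bracket)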
def add_resource(self, name, shape, **kw):
    # non-trainable resource variable, excluded from any distribution
    # aggregation or synchronization
    kw.setdefault('dtype', tf.floatx())
    kw.setdefault('trainable', False)
    kw.setdefault('use_resource', True)
    kw.setdefault('initializer', tf.zeros_initializer())
    kw.setdefault('aggregation', tf.VariableAggregation.NONE)
    kw.setdefault('synchronization', tf.VariableSynchronization.NONE)
    return self.add_variable(name, shape, **kw)
def add_weight(self, name, shape, **kw):
    kw.setdefault('dtype', tf.floatx())
    cfg = self.cfg
    # pick up initializer and regularizer settings from the config
    if hasattr(cfg, 'init_stddev'):
        kw.setdefault('initializer', tf.TruncatedNormal(stddev=cfg.init_stddev))
    if hasattr(cfg, 'regular_l1') and hasattr(cfg, 'regular_l2'):
        kw.setdefault('regularizer', tf.L1L2(cfg.regular_l1, cfg.regular_l2))
    return super().add_weight(name, shape, **kw)
def to_scores(self, qk, mask, v):
    b = 0
    if mask is not None:
        # large negative bias wherever the mask is False
        b = tf.logical_not(mask)
        b = tf.cast(b, tf.floatx()) * utils.big_neg()
        b = b[:, None, :, None]
    if self.proxim_b is not None:
        # proxim_b is already shaped [1, 1, end, end], see proximity()
        b += self.proxim_b
    y = tf.softmax(qk * self.scale + b)
    cfg = self.cfg
    y = self.drop(y, cfg.drop_attn or cfg.drop_hidden)
    # weighted sum of the value vectors
    y = tf.einsum('bnij,bnjv->bniv', y, v)
    return y
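# A shape sketch for the attention contraction above, with illustrative
# sizes: batch=2, heads=4, seq=8, value depth=16. The output keeps the
# query axis; tf.softmax is the same alias used by the code above.
qk = tf.zeros([2, 4, 8, 8])
v = tf.zeros([2, 4, 8, 16])
y = tf.einsum('bnij,bnjv->bniv', tf.softmax(qk), v)
assert y.shape == (2, 4, 8, 16)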
def proximity(end):
    # pairwise proximity bias -log(1 + |i - j|): 0 on the diagonal and
    # increasingly negative with distance
    y = tf.range(end, dtype=tf.floatx())
    y = y[None, :] - y[:, None]
    y = -tf.log1p(tf.abs(y))
    # shape [1, 1, end, end] so it broadcasts over batch and heads
    y = y[None, None]
    return y
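# For end=3 the bias matrix (before the two leading axes) is, rounded:
#   [[ 0.   , -0.693, -1.099],
#    [-0.693,  0.   , -0.693],
#    [-1.099, -0.693,  0.   ]]
# so attention is nudged toward nearby positions.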
def pos_timing(dim, end):
    # positions are counted down from end - 1 to 0
    t = tf.range(end - 1, -1, -1, dtype=tf.floatx())
    # tensor2tensor-style frequencies 1 / 10000^(2k / dim)
    f = tf.range(0, dim, 2.0, dtype=tf.floatx())
    f = 1 / (10000**(f / dim))
    t = tf.einsum('i,d->id', t, f)
    return tf.concat([tf.sin(t), tf.cos(t)], axis=-1)
def big_neg():
    # largest safely representable negative value for the active dtype
    f = tf.floatx()
    return tf.float16.min if f == 'float16' else -1e9
def add_bias(self, name, shape, **kw):
    # zero-initialized bias, bypassing the config-driven initializer
    kw.setdefault('dtype', tf.floatx())
    kw.setdefault('initializer', tf.zeros_initializer())
    return super().add_weight(name, shape, **kw)
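# A hypothetical build() combining the three helpers above; the names
# 'kern', 'off', 'step' and the shapes are illustrative, not taken from
# the original layers.
def build(self, input_shape):
    d = input_shape[-1]
    self.kern = self.add_weight('kern', (d, d))  # config-driven init
    self.off = self.add_bias('off', (d,))  # always zero-initialized
    self.step = self.add_resource('step', ())  # non-trainable counter
    super().build(input_shape)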
def penalty(self, n):
    # GNMT-style length penalty ((5 + n) / 6) ** beam_alpha
    n = tf.cast(n, tf.floatx())
    y = tf.pow(((5. + n) / 6.), self.cfg.beam_alpha)
    return y
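# With an illustrative beam_alpha of 0.6 the penalty grows slowly with
# length: n=1 gives 1.0 and n=7 gives 2.0**0.6 ≈ 1.516, so normalizing
# scores by it keeps longer finished hypotheses competitive.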
def top_out(self, x, lp, i):
    cfg = self.cfg
    # scores normalized by the penalty for the current length i + 1
    score = lp / self.penalty(i + 1)
    # only sequences ending in END are finished; push the rest way down
    flag = tf.equal(x[:, :, -1], cfg.END)
    score += (1. - tf.cast(flag, tf.floatx())) * utils.big_neg()
    return self.top_beams([x, score, flag], score)
def top_tgt(self, x, lp):
    cfg = self.cfg
    # finished sequences must not be extended; push them way down
    fs = tf.equal(x[:, :, -1], cfg.END)
    lp += tf.cast(fs, tf.floatx()) * utils.big_neg()
    return self.top_beams([x, lp], lp)
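# How the beam-search helpers above fit together, as a sketch (assuming
# a top_beams(tensors, scores) that keeps the k best-scoring beams):
# each step extends the live beams, top_tgt keeps the best unfinished
# ones, top_out keeps the best finished ones, and not_done stops the
# loop once no live beam can still beat the worst finished score.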