def shift(self, x):
    # Skew trick for relative attention: pad a zero column on the left of
    # the last axis, fold the pad into the next-to-last axis, then drop the
    # first row so each row's entries slide into position.
    s = x.shape
    y = tf.pad(x, [[0, 0], [0, 0], [0, 0], [1, 0]])
    y = tf.reshape(y, [s[0], s[1], s[3] + 1, s[2]])
    y = tf.slice(y, [0, 0, 1, 0], [-1, -1, -1, -1])
    y = tf.reshape(y, s)
    return y
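# A minimal sketch of the skew trick above on a toy input, assuming a rank-4
# (batch, heads, rows, cols) tensor as in relative-position attention; shapes
# and values are illustrative only. After the pad/reshape/slice, out[i][j]
# holds x[i][j - i + rows - 1] for j <= i, turning per-row relative offsets
# into absolute key positions (entries above the diagonal are garbage that a
# causal mask is expected to discard).
def _demo_shift():
    import tensorflow as tf
    x = tf.reshape(tf.range(9.0), (1, 1, 3, 3))
    y = tf.pad(x, [[0, 0], [0, 0], [0, 0], [1, 0]])  # zero column on the left
    y = tf.reshape(y, (1, 1, 4, 3))                  # fold pad into the rows
    y = y[:, :, 1:, :]                               # drop the first row
    y = tf.reshape(y, (1, 1, 3, 3))
    print(y[0, 0].numpy())  # [[2. 0. 3.] [4. 5. 0.] [6. 7. 8.]]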
def append_tok(self, idx, i, **kw):
    cfg = self.cfg
    k = 2 * cfg.beam_size
    # Batch index for each of the k candidates kept per batch entry.
    b = tf.range(cfg.batch_size * k) // k
    b = tf.reshape(b, (cfg.batch_size, k))
    # idx flattens (beam, token) pairs: recover the source beam of each
    # candidate and copy that beam's running target sequence.
    beam = idx // cfg.num_toks
    sel = tf.stack([b, beam], axis=2)
    y = tf.gather_nd(self.tgt, sel)
    # After the gather, row r of y holds candidate r, so the scatter must
    # address rows by candidate position rather than by the original beam
    # index; the updates are scalar token ids, one per (batch, candidate).
    r = tf.reshape(tf.range(cfg.batch_size * k) % k, (cfg.batch_size, k))
    ii = tf.fill((cfg.batch_size, k), i)
    sel = tf.stack([b, r, ii], axis=2)
    tgt = tf.tensor_scatter_nd_update(y, sel, idx % cfg.num_toks)
    return tgt
def search(self, tgt, ctx, i=None):
    cfg = self.cfg
    # Positions holding UNK are free to be predicted; every other position
    # is pinned to its given token by a one-hot prior (0 for the token,
    # big_neg everywhere else).
    unk = tf.equal(tgt, cfg.UNK)
    prior = tf.one_hot(tgt, cfg.num_toks, 0.0, utils.big_neg)
    if i is not None:
        unk = unk[:, i]
        prior = prior[:, i, :]
    if not tf.reduce_any(unk):
        # Nothing left to predict: the prior determines every position.
        y = logi = prior
    else:
        y = self.decode(tgt, ctx)
        if i is not None:
            y = y[:, i, :]
        sh = y.shape
        y = tf.reshape(y, (-1, sh[-1]))
        y = self.logits(y)
        y = tf.reshape(y, sh[:-1] + y.shape[-1:])
        u = tf.expand_dims(unk, axis=-1)
        u = tf.broadcast_to(u, y.shape)
        # Model logits where the target is UNK, pinned prior elsewhere.
        logi = tf.where(u, y, prior)
    logp = y - tf.reduce_logsumexp(y, axis=-1, keepdims=True)
    return logp, logi, unk
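# A small, self-contained illustration of the masking idea in search above,
# assuming utils.big_neg is a large negative constant (e.g. -1e9); UNK and
# the sizes below are made-up stand-ins.
def _demo_unk_mask():
    import tensorflow as tf
    UNK, num_toks, big_neg = 0, 5, -1e9
    tgt = tf.constant([[3, UNK, 1]])      # position 1 is still unknown
    unk = tf.equal(tgt, UNK)              # [[False, True, False]]
    # Pinned positions get a one-hot prior: 0 for the given token and
    # big_neg for everything else, so softmax/argmax keep that token.
    prior = tf.one_hot(tgt, num_toks, 0.0, big_neg)
    logits = tf.random.normal((1, 3, num_toks))   # stand-in decoder output
    u = tf.broadcast_to(tf.expand_dims(unk, -1), logits.shape)
    return tf.where(u, logits, prior)     # free positions use the model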
def call(self, inputs):
    prev, x = inputs
    y = x
    if self.cmd:
        cfg = self.cfg
        # Each character of cmd applies one step: 'a' residual add,
        # 'z' scaled (ReZero-style) residual add, 'n' normalize, 'd' dropout.
        for c in self.cmd:
            if c == 'a':
                y = prev + x
            elif c == 'z':
                y = prev + x * self.gamma
            elif c == 'n':
                if cfg.norm_type == 'layer':
                    y = _layer_norm(self, x)
                elif cfg.norm_type == 'batch':
                    y = self.batch(x)
                elif cfg.norm_type == 'l2':
                    m = tf.reduce_mean(x, axis=-1, keepdims=True)
                    n = tf.square(x - m)
                    # Sum (not mean) of squared deviations, so this differs
                    # from standard layer norm by a sqrt(d) factor.
                    n = tf.reduce_sum(n, axis=-1, keepdims=True)
                    y = (x - m) / tf.sqrt(n + cfg.norm_epsilon)
                    y = y * self.gain + self.bias
                elif cfg.norm_type == 'group':
                    sh = x.shape
                    assert len(sh) == 4 and sh[-1] % cfg.num_groups == 0
                    gs = (cfg.num_groups, sh[-1] // cfg.num_groups)
                    g = tf.reshape(x, sh[:-1] + gs)
                    # Per-(batch, group) statistics over H, W and the
                    # within-group channel axis.
                    m, v = tf.nn.moments(g, [1, 2, 4], keepdims=True)
                    y = (g - m) / tf.sqrt(v + cfg.group_epsilon)
                    y = tf.reshape(y, sh) * self.gain + self.bias
                elif cfg.norm_type == 'noam':
                    d = tf.cast(x.shape[-1], x.dtype)
                    y = tf.math.l2_normalize(x, axis=-1) * tf.sqrt(d)
                else:
                    assert cfg.norm_type == 'none'
            else:
                assert c == 'd'
                y = self.drop(y)
            x = y
    return y
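# A standalone check of the 'group' branch's reshape-and-moments math, with
# assumed values for num_groups and group_epsilon; the learned gain/bias
# affine step is omitted.
def _demo_group_norm():
    import tensorflow as tf
    num_groups, eps = 4, 1e-5
    x = tf.random.normal((2, 8, 8, 16))             # NHWC, 16 channels
    sh = x.shape
    g = tf.reshape(x, sh[:-1] + (num_groups, sh[-1] // num_groups))
    # One mean/variance per (batch, group).
    m, v = tf.nn.moments(g, [1, 2, 4], keepdims=True)
    return tf.reshape((g - m) / tf.sqrt(v + eps), sh)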
def top_logp(self, ctx, bias, i):
    cfg = self.cfg
    # Expand the running per-beam log-probs across the vocabulary.
    y = tf.zeros((cfg.batch_size, cfg.beam_size, cfg.num_toks))
    y += tf.expand_dims(self.logp, axis=2)
    b = tf.range(cfg.batch_size)
    # Add each beam's next-token scores into its row of y. The scatter
    # indices address (batch, beam) rows, so they are stacked on axis 1
    # and the updates carry the vocabulary axis.
    for j in range(cfg.beam_size):
        jj = tf.constant([j] * cfg.batch_size)
        sel = tf.stack([b, jj], axis=1)
        yj = self.to_logp(self.tgt[:, j, :], ctx, bias, i)[1]
        y = tf.tensor_scatter_nd_add(y, sel, yj)
    # Merge beam and vocab axes and keep the 2 * beam_size best
    # (beam, token) continuations per batch entry.
    y = tf.reshape(y, (-1, cfg.beam_size * cfg.num_toks))
    logp, idx = tf.math.top_k(y, k=2 * cfg.beam_size)
    return logp, idx
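# A tiny illustration of the flatten-then-top_k step at the end of top_logp,
# with made-up sizes. Merging the beam and vocab axes lets a single top_k
# rank all (beam, token) continuations jointly; append_tok above recovers
# the pair with // and % exactly as shown here.
def _demo_flat_topk():
    import tensorflow as tf
    batch, beams, toks = 2, 3, 5
    scores = tf.random.normal((batch, beams, toks))
    flat = tf.reshape(scores, (batch, beams * toks))
    logp, idx = tf.math.top_k(flat, k=2 * beams)
    beam, tok = idx // toks, idx % toks   # decode the flat index
    return logp, beam, tok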
def join_heads(x):
    # Inverse of split_heads: (batch, heads, time, depth) back to
    # (batch, time, heads * depth).
    y = tf.transpose(x, perm=[0, 2, 1, 3])
    s = y.shape
    y = tf.reshape(y, (-1, s[1], s[2] * s[3]))
    return y
def split_heads(self, x):
    # (batch, time, width) to (batch, heads, time, width // heads).
    s = x.shape
    n = self.cfg.num_heads
    y = tf.reshape(x, (-1, s[1], n, s[-1] // n))
    y = tf.transpose(y, perm=[0, 2, 1, 3])
    return y
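# A quick round trip showing split_heads and join_heads are inverses,
# assuming the model width divides evenly by the head count; sizes are
# illustrative.
def _demo_heads_round_trip():
    import tensorflow as tf
    seq, heads, width = 7, 4, 32
    x = tf.random.normal((2, seq, width))
    y = tf.reshape(x, (-1, seq, heads, width // heads))
    y = tf.transpose(y, perm=[0, 2, 1, 3])   # (batch, heads, seq, depth)
    z = tf.transpose(y, perm=[0, 2, 1, 3])   # undo the transpose
    z = tf.reshape(z, (-1, seq, width))
    tf.debugging.assert_near(x, z)           # round trip is exact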
def gather_beams(self, xs, beams, k):
    cfg = self.cfg
    # Pair every (batch, slot) with its chosen beam and gather that beam's
    # entry from each tensor in the (possibly nested) structure xs.
    b = tf.range(cfg.batch_size * k) // k
    b = tf.reshape(b, (cfg.batch_size, k))
    sel = tf.stack([b, beams], axis=2)
    return tf.nest.map_structure(lambda x: tf.gather_nd(x, sel), xs)
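# A minimal use of the gather pattern in gather_beams, with toy shapes.
# tf.nest.map_structure applies the same (batch, beam) gather to every
# tensor in a nested decoding state, keeping caches aligned with sequences.
def _demo_gather_beams():
    import tensorflow as tf
    batch, k, length = 2, 3, 4
    seqs = tf.reshape(tf.range(batch * k * length), (batch, k, length))
    beams = tf.constant([[2, 0, 0], [1, 1, 2]])  # chosen beam per slot
    b = tf.reshape(tf.range(batch * k) // k, (batch, k))
    sel = tf.stack([b, beams], axis=2)           # (batch, k, 2) indices
    return tf.gather_nd(seqs, sel)               # rows reordered per batch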