Example #1
 def shift(self, x):
     # Relative-attention "skew" trick: pad one column on the left of the
     # last axis, fold it into the next-to-last axis, then drop the first
     # row so relative-position logits line up with absolute positions.
     s = x.shape  # tf.int_shape(x)
     y = tf.pad(x, [[0, 0], [0, 0], [0, 0], [1, 0]])
     y = tf.reshape(y, [s[0], s[1], s[3] + 1, s[2]])
     y = tf.slice(y, [0, 0, 1, 0], [-1, -1, -1, -1])
     y = tf.reshape(y, s)
     return y
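
The reshapes above need concrete dimensions, so a quick eager sanity check is easy to run. A minimal sketch; `shift` is called as a free function here only because `self` is unused:

 import tensorflow as tf

 x = tf.reshape(tf.range(9, dtype=tf.float32), (1, 1, 3, 3))
 # Prints the skewed 3x3 block produced by the pad/reshape/slice trick.
 print(shift(None, x)[0, 0])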
Example #2
 def append_tok(self, idx, i, **kw):
     cfg = self.cfg
     k = 2 * cfg.beam_size
     # Batch index for each of the k candidates per batch entry.
     b = tf.range(cfg.batch_size * k) // k
     b = tf.reshape(b, (cfg.batch_size, k))
     # idx enumerates (beam, token) pairs flattened over num_toks.
     beam = idx // cfg.num_toks
     sel = tf.stack([b, beam], axis=2)
     # Gather the surviving parent-beam prefixes.
     y = tf.gather_nd(self.tgt, sel)
     ii = tf.constant([i] * cfg.batch_size * k)
     ii = tf.reshape(ii, (cfg.batch_size, k))
     sel = tf.stack([b, beam, ii], axis=2)
     # Write the newly chosen token at step i. With (batch, k, 3) indices
     # the updates must be rank 2, so no expand_dims here.
     u = idx % cfg.num_toks
     tgt = tf.tensor_scatter_nd_update(y, sel, u)
     return tgt
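
The `(b, beam)` pairs built with `tf.stack(..., axis=2)` are how `tf.gather_nd` reorders beams per batch row. A minimal standalone sketch; the shapes and values are illustrative, not from the source:

 import tensorflow as tf

 tgt = tf.constant([[[10, 11], [20, 21]],      # batch 0, beams 0/1
                    [[30, 31], [40, 41]]])     # batch 1, beams 0/1
 b = tf.constant([[0, 0], [1, 1]])             # batch index per candidate
 beam = tf.constant([[1, 0], [0, 0]])          # parent beam per candidate
 sel = tf.stack([b, beam], axis=2)             # (batch, k, 2) index pairs
 print(tf.gather_nd(tgt, sel))                 # beam prefixes reordered per row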
Example #3
 def search(self, tgt, ctx, i=None):
     cfg = self.cfg
     unk = tf.equal(tgt, cfg.UNK)
     # Known tokens get a one-hot prior: 0 for the observed token and a
     # large negative value everywhere else.
     prior = tf.one_hot(tgt, cfg.num_toks, 0.0, utils.big_neg)
     if i is not None:
         unk = unk[:, i]
         prior = prior[:, i, :]
     # `is True` never holds for a Tensor; test the boolean value itself.
     if tf.reduce_all(unk):
         y = logi = prior
     else:
         y = self.decode(tgt, ctx)
         if i is not None:
             y = y[:, i, :]
         sh = y.shape  # tf.int_shape(y)
         y = tf.reshape(y, (-1, sh[-1]))
         y = self.logits(y)
         y = tf.reshape(y, sh[:-1] + y.shape[-1:])
         # axis=-1 also works when i is given and unk has rank 1.
         u = tf.expand_dims(unk, axis=-1)
         u = tf.broadcast_to(u, y.shape)
         # Unknown positions take the model's logits; known positions
         # keep the one-hot prior.
         logi = tf.where(u, y, prior)
     logp = y - tf.reduce_logsumexp(y, axis=-1, keepdims=True)
     return logp, logi, unk
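
The last line is just a log-softmax written out by hand. A quick sketch of the identity it relies on:

 import tensorflow as tf

 y = tf.constant([[1.0, 2.0, 3.0]])
 logp = y - tf.reduce_logsumexp(y, axis=-1, keepdims=True)
 # Matches tf.nn.log_softmax up to float rounding.
 print(tf.reduce_max(tf.abs(logp - tf.nn.log_softmax(y, axis=-1))))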
Example #4
 def call(self, inputs):
     prev, x = inputs
     y = x
     if self.cmd:
         cfg = self.cfg
         for c in self.cmd:
             if c == 'a':  # residual add
                 y = prev + x
             elif c == 'z':  # scaled residual add
                 y = prev + x * self.gamma
             elif c == 'n':  # normalize
                 if cfg.norm_type == 'layer':
                     y = _layer_norm(self, x)
                 elif cfg.norm_type == 'batch':
                     y = self.batch(x)
                 elif cfg.norm_type == 'l2':
                     m = tf.reduce_mean(x, axis=-1, keepdims=True)
                     n = tf.square(x - m)
                     n = tf.reduce_sum(n, axis=-1, keepdims=True)
                     y = (x - m) / tf.sqrt(n + cfg.norm_epsilon)
                     y = y * self.gain + self.bias
                 elif cfg.norm_type == 'group':
                     sh = x.shape  # tf.int_shape(x)
                     assert len(sh) == 4 and sh[-1] % cfg.num_groups == 0
                     gs = (cfg.num_groups, sh[-1] // cfg.num_groups)
                     # Split channels into (groups, channels per group).
                     x = tf.reshape(x, sh[:-1] + gs)
                     m, v = tf.nn.moments(x, [1, 2, 4], keepdims=True)
                     y = (x - m) / tf.sqrt(v + cfg.group_epsilon)
                     y = tf.reshape(y, sh) * self.gain + self.bias
                 elif cfg.norm_type == 'noam':
                     y = tf.cast(x.shape[-1], tf.float32)
                     y = tf.math.l2_normalize(x, axis=-1) * tf.sqrt(y)
                 else:
                     assert cfg.norm_type == 'none'
             else:
                 assert c == 'd'  # dropout
                 y = self.drop(y)
             x = y
     return y
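
The 'l2' branch is a hand-rolled layer normalization. A standalone sketch of the same computation; the epsilon and shapes are illustrative:

 import tensorflow as tf

 x = tf.random.normal((2, 4, 8))
 eps = 1e-6
 m = tf.reduce_mean(x, axis=-1, keepdims=True)
 n = tf.reduce_sum(tf.square(x - m), axis=-1, keepdims=True)
 y = (x - m) / tf.sqrt(n + eps)
 # n is the *sum* of squared deviations, not the variance, so this scales
 # differently from tf.keras LayerNormalization by a factor of sqrt(d).
 print(y.shape)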
Example #5
 def top_logp(self, ctx, bias, i):
     cfg = self.cfg
     y = tf.zeros((
         cfg.batch_size,
         cfg.beam_size,
         cfg.num_toks,
     ))
     # Broadcast each beam's running log-prob over the token axis.
     y += tf.expand_dims(self.logp, axis=2)
     b = tf.range(cfg.batch_size)
     for j in range(cfg.beam_size):
         jj = tf.constant([j] * cfg.batch_size)
         # (batch, 2) indices select row [b, j]; the (batch, num_toks)
         # update is then added across the whole token axis.
         sel = tf.stack([b, jj], axis=1)
         yj = self.to_logp(self.tgt[:, j, :], ctx, bias, i)[1]
         y = tf.tensor_scatter_nd_add(y, sel, yj)
     # Rank all beam/token continuations jointly.
     y = tf.reshape(y, (-1, cfg.beam_size * cfg.num_toks))
     logp, idx = tf.math.top_k(y, k=2 * cfg.beam_size)
     return logp, idx
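
The flattened `top_k` indices are what `append_tok` later splits back into (beam, token) pairs with `//` and `%`. A tiny numeric sketch with illustrative values:

 import tensorflow as tf

 beam_size, num_toks = 2, 4
 y = tf.constant([[0.1, 0.9, 0.2, 0.0,    # beam 0, flattened
                   0.3, 0.8, 0.7, 0.4]])  # beam 1, flattened
 logp, idx = tf.math.top_k(y, k=2 * beam_size)
 print(idx // num_toks)  # parent beam of each candidate: [[0 1 1 1]]
 print(idx % num_toks)   # token id of each candidate:    [[1 1 2 3]]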
Example #6
 def join_heads(x):
     # (batch, heads, seq, depth) -> (batch, seq, heads, depth)
     y = tf.transpose(x, perm=[0, 2, 1, 3])
     s = y.shape  # tf.int_shape(y)
     # Merge heads and depth back into a single feature axis.
     y = tf.reshape(y, (-1, s[1], s[2] * s[3]))
     return y
Example #7
 def split_heads(self, x):
     s = x.shape  # tf.int_shape(x)
     n = self.cfg.num_heads
     # (batch, seq, dim) -> (batch, seq, heads, dim // heads)
     y = tf.reshape(x, (-1, s[1], n, s[-1] // n))
     # -> (batch, heads, seq, dim // heads)
     y = tf.transpose(y, perm=[0, 2, 1, 3])
     return y
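
`split_heads` and `join_heads` are inverses of each other. A quick roundtrip check, inlining both transforms with a stand-in head count:

 import tensorflow as tf

 x = tf.random.normal((2, 5, 8))                 # (batch, seq, dim)
 num_heads = 4
 s = x.shape
 y = tf.reshape(x, (-1, s[1], num_heads, s[-1] // num_heads))
 y = tf.transpose(y, perm=[0, 2, 1, 3])          # split: (2, 4, 5, 2)
 y = tf.transpose(y, perm=[0, 2, 1, 3])
 y = tf.reshape(y, (-1, s[1], s[-1]))            # join: back to (2, 5, 8)
 print(tf.reduce_all(tf.equal(x, y)))            # True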
Example #8
 def gather_beams(self, xs, beams, k):
     cfg = self.cfg
     # Batch index for each of the k kept beams per batch entry.
     b = tf.range(cfg.batch_size * k) // k
     b = tf.reshape(b, (cfg.batch_size, k))
     sel = tf.stack([b, beams], axis=2)
     # tf.nest is the public alias of the `nest` utility assumed here.
     return tf.nest.map_structure(lambda x: tf.gather_nd(x, sel), xs)
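
A minimal sketch of the same gather applied to a whole structure of tensors, standalone and with a stand-in config:

 import tensorflow as tf

 batch_size, k = 2, 2
 b = tf.reshape(tf.range(batch_size * k) // k, (batch_size, k))
 beams = tf.constant([[1, 1], [0, 1]])           # chosen parent per slot
 sel = tf.stack([b, beams], axis=2)
 xs = {"tok": tf.constant([[[1], [2]], [[3], [4]]]),
       "len": tf.constant([[5, 6], [7, 8]])}
 # Every tensor in xs is re-indexed by the same (batch, beam) pairs.
 print(tf.nest.map_structure(lambda x: tf.gather_nd(x, sel), xs))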