def call(self, inputs):
     """
     Compute next-token logits and the masked, per-example language-model loss.

     The output projection reuses the first ``n_vocab`` rows of the embedding
     matrix (``self.embed.we``) after dropout, so input and output embeddings
     are tied. The loss at position ``t`` compares the logits at ``t`` with
     the token ID at ``t + 1``; the last position has no target and is
     dropped.

     Args:
         inputs: it is list of ID and positions of tokens and their mask.
                 tokens shape = (batch size, context length, 2 (IDs and positions))
                 masks shape = (batch size, context length)

     Returns:
         logits: shape = (batch size, context length, vocab size)
         losses: shape = (batch size, )
     """
     tokens = tf.reshape(inputs[0], [-1, self.n_ctx, 2])
     masks = tf.reshape(inputs[1], [-1, self.n_ctx])
     hidden1 = self.embed(tokens)
     # Keep the dropped-out embedding matrix in a local instead of assigning
     # it back to ``self.embed.we``: overwriting the layer attribute would
     # replace the weight with a derived tensor, so every later forward pass
     # would start from an already-dropped matrix.
     we = dropout(self.embed.we, self.embd_pdrop, self.train)
     hidden2 = self.transformer_stack(hidden1)
     hidden3 = tf.reshape(hidden2, [-1, self.n_embd])
     # Only the first n_vocab rows are used for the projection; any rows
     # after them are excluded from the output vocabulary.
     logits = tf.reshape(
         tf.matmul(hidden3, we[:self.n_vocab, :], transpose_b=True),
         [-1, self.n_ctx, self.n_vocab])
     # Shift by one: logits[:, :-1] predict tokens[:, 1:].
     logits_truncated = tf.reshape(logits[:, :-1], [-1, self.n_vocab])
     losses = tf.nn.sparse_softmax_cross_entropy_with_logits(
         logits=logits_truncated, labels=tf.reshape(tokens[:, 1:, 0], [-1]))
     tokens_shape = shape_list(tokens)
     losses = tf.reshape(losses, [tokens_shape[0], tokens_shape[1] - 1])
     # Mean over the unmasked target positions of each example.
     # NOTE(review): an example whose mask is all zero from position 1 on
     # makes the denominator zero - confirm callers never pass such rows.
     target_masks = masks[:, 1:]
     losses = tf.reduce_sum(losses * target_masks, 1) / tf.reduce_sum(
         target_masks, 1)
     return logits, losses
 def _attn(self, q, k, v):
     """
     Apply masked, optionally scaled attention of ``q`` over ``k`` and ``v``.

     Args:
         q: query tensor.
         k: key tensor; it is multiplied as ``tf.matmul(q, k)`` with no
            transpose applied here, so it must already be laid out for that
            product.
         v: value tensor; its last dimension sets the scaling factor when
            ``self.scale`` is truthy.

     Returns:
         The attention-weighted combination ``tf.matmul(weights, v)``.
     """
     scores = tf.matmul(q, k)
     if self.scale:
         # Scale by 1 / sqrt(size of v's last dimension).
         head_size = tf.cast(shape_list(v)[-1], tf.float32)
         scores = scores * tf.rsqrt(head_size)
     # Mask first, then normalise, then apply attention dropout.
     weights = tf.nn.softmax(self.mask_attn_weights(scores))
     weights = dropout(weights, self.attn_pdrop, self.train)
     return tf.matmul(weights, v)
 def call(self, inputs, **kwargs):
     """
     Project the last axis of ``inputs`` from ``self.nx`` to ``self.nf``.

     When ``self.rf`` is 1 the layer is a plain matrix multiply applied to
     the flattened inputs; otherwise it is a stride-1, 'VALID'-padded 1-D
     convolution with ``self.w`` as the filters. ``self.b`` is added in both
     cases.

     Args:
         inputs: input tensor whose last dimension is flattened against
                 ``self.nx`` in the ``rf == 1`` path.
         **kwargs: accepted for interface compatibility and ignored.

     Returns:
         The projected tensor; in the ``rf == 1`` path it has the input's
         leading dimensions followed by ``self.nf``.
     """
     if self.rf != 1:
         return tf.nn.conv1d(
             value=inputs, filters=self.w, stride=1,
             padding='VALID') + self.b

     # Pointwise case: collapse everything but the feature axis, multiply,
     # then restore the leading dimensions.
     flat_inputs = tf.reshape(inputs, [-1, self.nx])
     flat_weights = tf.reshape(self.w, [-1, self.nf])
     projected = tf.matmul(flat_inputs, flat_weights) + self.b
     out_shape = shape_list(inputs)[:-1] + [self.nf]
     return tf.reshape(projected, out_shape)
 def mask_attn_weights(self, w):
     """
     Mask attention scores so each position sees only itself and earlier ones.

     A lower-triangular matrix of ones (diagonal included) is built from the
     size of the last dimension of ``w`` and broadcast as ``[1, 1, n, n]``.
     Entries outside the triangle are replaced by -1e9.

     Args:
         w: attention scores; presumably (batch, heads, n, n) given the
            ``[1, 1, n, n]`` mask - TODO confirm against the caller.

     Returns:
         Scores with the same shape as ``w`` and masked entries set to -1e9.
     """
     size = shape_list(w)[-1]
     lower_triangle = tf.matrix_band_part(tf.ones([size, size]), -1, 0)
     keep = tf.reshape(lower_triangle, [1, 1, size, size])
     # keep == 1 passes the score through; keep == 0 yields exactly -1e9.
     return w * keep + -1e9 * (1 - keep)
 def merge_states(self, x):
     """
     Collapse the last two dimensions of ``x`` into a single dimension.

     Args:
         x: tensor with at least two dimensions.

     Returns:
         ``x`` reshaped so its final dimension is the product of the original
         last two; all leading dimensions are kept.
     """
     dims = shape_list(x)
     leading, trailing = dims[:-2], dims[-2:]
     return tf.reshape(x, leading + [np.prod(trailing)])
 def split_states(self, x, n):
     """
     Split the last dimension of ``x`` into ``n`` groups.

     Args:
         x: tensor whose last dimension of size ``m`` is to be split.
         n: number of groups; each group gets ``m // n`` elements.

     Returns:
         ``x`` reshaped to its leading dimensions followed by
         ``[n, m // n]``.
     """
     dims = shape_list(x)
     group_size = dims[-1] // n
     return tf.reshape(x, dims[:-1] + [n, group_size])