def call(self, inputs):
    """Run the language-model forward pass and compute per-sequence losses.

    Args:
        inputs: pair ``(tokens, masks)``.
            tokens is reshaped to (batch size, context length, 2), where
            the last axis holds the token ID and the position ID.
            masks is reshaped to (batch size, context length).

    Returns:
        logits: shape = (batch size, context length, vocab size)
        losses: shape = (batch size,). Masked mean of the next-token
            cross-entropy over each sequence. A sequence whose shifted
            mask sums to zero divides by zero here.
    """
    tokens = tf.reshape(inputs[0], [-1, self.n_ctx, 2])
    masks = tf.reshape(inputs[1], [-1, self.n_ctx])

    hidden1 = self.embed(tokens)
    # Keep the dropped-out embedding matrix in a local instead of assigning it
    # back to ``self.embed.we``: overwriting the layer attribute would replace
    # the weight with a dropout output and stack dropout on repeated calls.
    we = dropout(self.embed.we, self.embd_pdrop, self.train)
    hidden2 = self.transformer_stack(hidden1)
    hidden3 = tf.reshape(hidden2, [-1, self.n_embd])

    # The output projection reuses the first ``n_vocab`` rows of the
    # embedding matrix (transposed).
    logits = tf.reshape(
        tf.matmul(hidden3, we[:self.n_vocab, :], transpose_b=True),
        [-1, self.n_ctx, self.n_vocab])

    # Logits at position t are scored against the token ID at position t + 1,
    # so the last logit and the first token are dropped.
    logits_truncated = tf.reshape(logits[:, :-1], [-1, self.n_vocab])
    losses = tf.nn.sparse_softmax_cross_entropy_with_logits(
        logits=logits_truncated,
        labels=tf.reshape(tokens[:, 1:, 0], [-1]))
    losses = tf.reshape(
        losses, [shape_list(tokens)[0], shape_list(tokens)[1] - 1])

    # Average only over positions whose (shifted) mask is non-zero.
    shifted_masks = masks[:, 1:]
    losses = tf.reduce_sum(losses * shifted_masks, 1) / tf.reduce_sum(
        shifted_masks, 1)
    return logits, losses
def _attn(self, q, k, v):
    """Apply masked, softmax-normalised attention of ``q`` over ``k`` to ``v``.

    The raw scores are ``matmul(q, k)`` (presumably ``k`` arrives already
    transposed -- confirm against the caller). When ``self.scale`` is truthy
    the scores are multiplied by ``1 / sqrt(last dimension of v)``. The scores
    are then masked, passed through softmax and dropout, and used to weight
    ``v``.
    """
    scores = tf.matmul(q, k)
    if self.scale:
        depth = shape_list(v)[-1]
        scores = scores * tf.rsqrt(tf.cast(depth, tf.float32))

    masked_scores = self.mask_attn_weights(scores)
    probs = tf.nn.softmax(masked_scores)
    probs = dropout(probs, self.attn_pdrop, self.train)
    return tf.matmul(probs, v)
def call(self, inputs, **kwargs):
    """Apply the layer's 1-D convolution (weights ``self.w``, bias ``self.b``).

    When ``self.rf`` is 1 the convolution is computed as a single matrix
    multiply over the flattened inputs; otherwise ``tf.nn.conv1d`` is used
    with stride 1 and 'VALID' padding.
    """
    if self.rf != 1:
        # General case: defer to the real convolution op.
        return tf.nn.conv1d(
            value=inputs, filters=self.w, stride=1, padding='VALID') + self.b

    # rf == 1: flatten to 2-D, multiply, then restore the leading dimensions
    # with the last one replaced by ``self.nf``.
    flat_inputs = tf.reshape(inputs, [-1, self.nx])
    flat_kernel = tf.reshape(self.w, [-1, self.nf])
    projected = tf.matmul(flat_inputs, flat_kernel) + self.b
    out_shape = shape_list(inputs)[:-1] + [self.nf]
    return tf.reshape(projected, out_shape)
def mask_attn_weights(self, w):
    """Mask attention weights with a lower-triangular pattern.

    Entries on or below the diagonal are kept unchanged; entries above the
    diagonal are replaced by -1e9 so they vanish after softmax. The mask is
    built as ``n x n`` from the size of the last dimension of ``w`` and
    reshaped to ``[1, 1, n, n]`` before being combined with ``w``.
    """
    size = shape_list(w)[-1]
    lower_tri = tf.matrix_band_part(tf.ones([size, size]), -1, 0)
    keep = tf.reshape(lower_tri, [1, 1, size, size])
    penalty = -1e9 * (1 - keep)
    return w * keep + penalty
def merge_states(self, x):
    """Reshape ``x`` so that its last two dimensions are merged into one.

    The new last dimension is the product of the original last two; all
    leading dimensions are kept as they are.
    """
    dims = shape_list(x)
    leading = dims[:-2]
    merged = np.prod(dims[-2:])
    return tf.reshape(x, leading + [merged])
def split_states(self, x, n):
    """Reshape ``x`` so that its last dimension is split into ``[n, m // n]``.

    ``m`` is the size of the last dimension of ``x``; all leading dimensions
    are kept as they are.
    """
    dims = shape_list(x)
    return tf.reshape(x, dims[:-1] + [n, dims[-1] // n])