import jax
import jax.numpy as jnp
from jax import random


def Init(shape, rng, stddev=1.0):
  """Returns orthogonalized random normal values with the given `shape`.

  Args:
    shape: output shape.
    rng: JAX PRNG key used to draw the random values.
    stddev: scale applied to the orthogonal matrix before returning it.
  """
  # Have at least 2 elements in shape.
  cur_shape = list(shape)
  while len(cur_shape) < 2:
    cur_shape = [1] + cur_shape

  # Flatten the input shape with the last dimension remaining.
  n_rows = 1
  for dim in cur_shape[:-1]:
    n_rows *= dim
  n_cols = cur_shape[-1]
  flat_shape = (n_cols, n_rows) if n_rows < n_cols else (n_rows, n_cols)

  # Generate a random matrix.
  a = random.normal(rng, flat_shape, dtype=jnp.float32)

  # Compute the QR factorization.
  q, r = jnp.linalg.qr(a)

  # Make Q uniform by flipping column signs according to the diagonal of R.
  d = jnp.diag(r)
  q *= jnp.sign(d)

  # Transpose and reshape q back if needed.
  if n_rows < n_cols:
    q = jnp.transpose(q)
  q = jnp.reshape(q, shape)

  # Return scaled as requested.
  return stddev * q
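# Usage sketch (illustrative, not part of the original code): draw a 4x8
# orthogonal weight matrix with an arbitrary PRNG key. Because n_rows < n_cols
# here, the rows of the result are orthonormal, so w @ w.T is (numerically)
# the 4x4 identity.
def example_orthogonal_init():
  key = random.PRNGKey(0)  # Arbitrary key, for illustration only.
  w = Init((4, 8), key)
  assert w.shape == (4, 8)
  assert jnp.allclose(jnp.dot(w, w.T), jnp.eye(4), atol=1e-5)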
def CombineHeadsPos(x, n_heads=1, **unused_kwargs):
  """Mix x = (x0, p0, ..., xH, pH) into (x0, ..., xH), p_combined.

  The positions are averaged as vectors.

  Args:
    x: input vector, concatenated (x0, p0, ..., xH, pH).
    n_heads: number of heads.

  Returns:
    the vector with combined xs and one with combined positions.
  """
  seqlen = x.shape[1]
  d_head = x.shape[2]
  x = np.reshape(x, (-1, n_heads, seqlen, d_head))
  x = np.transpose(x, (0, 2, 1, 3))  # -> n_batch, seqlen, n_heads, d_head
  x = np.reshape(x, (-1, seqlen, n_heads * d_head))
  head_size = int(d_head) - POS_VECTOR_SIZE
  res, positions, idx = [], [], 0
  for _ in range(n_heads):
    res.append(x[:, :, idx:idx + head_size])
    idx += head_size
    positions.append(x[:, :, idx:idx + POS_VECTOR_SIZE])
    idx += POS_VECTOR_SIZE
  combined_position = sum(positions) / float(len(positions))
  return np.concatenate(res, axis=-1), combined_position
def compute_attention_heads(x):
  batch_size = x.shape[0]
  seqlen = x.shape[1]
  # n_batch, seqlen, n_heads*d_head -> n_batch, seqlen, n_heads, d_head
  x = jnp.reshape(x, (batch_size, seqlen, n_heads, d_head))
  # n_batch, seqlen, n_heads, d_head -> n_batch, n_heads, seqlen, d_head
  x = jnp.transpose(x, (0, 2, 1, 3))
  # n_batch, n_heads, seqlen, d_head -> n_batch*n_heads, seqlen, d_head
  return jnp.reshape(x, (-1, seqlen, d_head))
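# Shape sketch (illustrative, not from the original code): the closure above
# relies on `n_heads` and `d_head` being bound in its enclosing scope; the same
# reshape/transpose is traced inline here with hypothetical n_heads=2, d_head=3.
def example_split_heads_shapes():
  n_heads, d_head = 2, 3
  x = jnp.zeros((4, 5, n_heads * d_head))      # (n_batch, seqlen, n_heads*d_head)
  y = jnp.transpose(jnp.reshape(x, (4, 5, n_heads, d_head)), (0, 2, 1, 3))
  y = jnp.reshape(y, (-1, 5, d_head))
  assert y.shape == (4 * n_heads, 5, d_head)   # Heads folded into the batch.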
def create_weights_unbatched(self, input_signature, rng):
  d_model = input_signature[0].shape[-1]
  d_kv_antecedent = input_signature[1].shape[-1]
  rng_q, rng_k, rng_v, rng_o = jax.random.split(rng, 4)
  w_q = self._kernel_initializer((d_model, self.d_qk), rng_q)
  w_k = self._kernel_initializer((d_kv_antecedent, self.d_qk), rng_k)
  w_v = self._kernel_initializer((d_kv_antecedent, self.d_v), rng_v)
  w_o = np.transpose(self._kernel_initializer((d_model, self.d_v), rng_o))
  return (w_q, w_k, w_v, w_o)
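# Shape sketch (illustrative, not from the original class): with a stand-in
# initializer of the same (shape, rng) signature and hypothetical sizes
# d_model = d_kv_antecedent = 8, d_qk = d_v = 4, the projections above come out
# as w_q: (8, 4), w_k: (8, 4), w_v: (8, 4), w_o: (4, 8) -- note that w_o is
# created as (d_model, d_v) and stored transposed.
def example_projection_shapes():
  init = lambda shape, rng: random.normal(rng, shape)   # Stand-in initializer.
  rng_q, rng_k, rng_v, rng_o = jax.random.split(random.PRNGKey(0), 4)
  d_model, d_kv_antecedent, d_qk, d_v = 8, 8, 4, 4
  w_q = init((d_model, d_qk), rng_q)
  w_k = init((d_kv_antecedent, d_qk), rng_k)
  w_v = init((d_kv_antecedent, d_v), rng_v)
  w_o = jnp.transpose(init((d_model, d_v), rng_o))
  assert w_o.shape == (d_v, d_model)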
def forward(self, x, weights):
  """Combines attention heads and applies the output projection."""
  seqlen = x.shape[1]
  d_head = x.shape[2]
  x = np.reshape(x, (-1, self._n_heads, seqlen, d_head))
  x = np.transpose(x, (0, 2, 1, 3))  # -> n_batch, seqlen, n_heads, d_head
  x = np.reshape(x, (-1, seqlen, self._n_heads * d_head))
  return np.dot(x, weights)
def forward(self, x, weights):
  """Projects the input and splits the result into attention heads."""
  seqlen = x.shape[1]
  res = np.dot(x, weights)
  # n_batch, seqlen, n_heads*d_head -> n_batch, seqlen, n_heads, d_head
  res = np.reshape(res, (x.shape[0], seqlen, self._n_heads, self._d_head))
  # n_batch, seqlen, n_heads, d_head -> n_batch, n_heads, seqlen, d_head
  res = np.transpose(res, (0, 2, 1, 3))
  # n_batch, n_heads, seqlen, d_head -> n_batch*n_heads, seqlen, d_head
  res = np.reshape(res, (-1, seqlen, self._d_head))
  return res
def create_weights_unbatched(self, input_signature, rng):
  """Creates projection weights; queries and keys share one matrix if `share_qk`."""
  d_model = input_signature.shape[-1]
  rng_q, rng_k, rng_v, rng_o = jax.random.split(rng, 4)
  w_q = self._kernel_initializer((d_model, self.d_qk), rng_q)
  if not self.share_qk:
    w_k = self._kernel_initializer((d_model, self.d_qk), rng_k)
  w_v = self._kernel_initializer((d_model, self.d_v), rng_v)
  w_o = np.transpose(self._kernel_initializer((d_model, self.d_v), rng_o))
  if self.share_qk:
    return (w_q, w_v, w_o)
  else:
    return (w_q, w_k, w_v, w_o)
def forward(self, inp, weights):
  """Reshape input to have heads dimension and concatenate positions there."""
  x = inp[0]
  n_batches, seqlen = x.shape[0], x.shape[1]
  d_head = x.shape[-1] // self._n_heads
  res = np.reshape(x, (n_batches, seqlen, self._n_heads, d_head))
  res = np.transpose(res, (0, 2, 1, 3))  # (batch, heads, len, depth)
  if self._n_pos == 1:  # Just one position given, tile into each head.
    pos_shape = list(res.shape)[:-1] + [inp[1].shape[-1]]
    pos = inp[1][:, None, :, :] + np.zeros(pos_shape)  # Add 0 to broadcast.
  else:  # As many positions as heads, concatenate them in.
    pos = [p[:, None, :, :] for p in inp[1:]]
    pos = np.concatenate(pos, axis=1)
  res = np.concatenate([res, pos], axis=-1)
  # n_batch, n_heads, seqlen, d_head -> n_batch*n_heads, seqlen, d_head
  res = np.reshape(res, (-1, seqlen, d_head + POS_VECTOR_SIZE))
  return res
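# Shape sketch (illustrative, with hypothetical sizes): tiling a single shared
# position tensor into every head, as in the `_n_pos == 1` branch above. Here
# n_heads=2, d_head=3, and pos_size=4 stands in for POS_VECTOR_SIZE.
def example_heads_with_positions():
  n_batches, seqlen, n_heads, d_head, pos_size = 2, 5, 2, 3, 4
  x = jnp.zeros((n_batches, seqlen, n_heads * d_head))
  pos = jnp.zeros((n_batches, seqlen, pos_size))
  res = jnp.transpose(jnp.reshape(x, (n_batches, seqlen, n_heads, d_head)),
                      (0, 2, 1, 3))                      # (batch, heads, len, depth)
  tiled_pos = pos[:, None, :, :] + jnp.zeros(list(res.shape[:-1]) + [pos_size])
  res = jnp.concatenate([res, tiled_pos], axis=-1)
  res = jnp.reshape(res, (-1, seqlen, d_head + pos_size))
  assert res.shape == (n_batches * n_heads, seqlen, d_head + pos_size)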
def JoinHeads(x):  # pylint: disable=invalid-name
  return np.reshape(
      np.transpose(x, (0, 2, 1, 3)), (nbatch, -1, n_heads * d_head))
def SplitHeads(x):
  return np.transpose(
      np.reshape(x, (nbatch, -1, n_heads, d_head)), (0, 2, 1, 3))
def compute_attention_output(x):
  seqlen = x.shape[1]
  x = jnp.reshape(x, (-1, n_heads, seqlen, d_head))
  x = jnp.transpose(x, (0, 2, 1, 3))  # -> n_batch, seqlen, n_heads, d_head
  return jnp.reshape(x, (-1, seqlen, n_heads * d_head))
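# Round-trip sketch (illustrative, not from the original code): the reshape
# above exactly undoes compute_attention_heads. Both closures expect `n_heads`
# and `d_head` to be bound in their enclosing scope, so the trace is done
# inline here with hypothetical n_heads=2, d_head=3.
def example_heads_round_trip():
  n_heads, d_head, n_batch, seqlen = 2, 3, 4, 5
  x = jnp.arange(n_batch * seqlen * n_heads * d_head, dtype=jnp.float32)
  x = jnp.reshape(x, (n_batch, seqlen, n_heads * d_head))
  # Split: (n_batch, seqlen, n_heads*d_head) -> (n_batch*n_heads, seqlen, d_head).
  heads = jnp.reshape(
      jnp.transpose(jnp.reshape(x, (n_batch, seqlen, n_heads, d_head)),
                    (0, 2, 1, 3)), (-1, seqlen, d_head))
  # Join: (n_batch*n_heads, seqlen, d_head) -> (n_batch, seqlen, n_heads*d_head).
  joined = jnp.reshape(
      jnp.transpose(jnp.reshape(heads, (-1, n_heads, seqlen, d_head)),
                    (0, 2, 1, 3)), (-1, seqlen, n_heads * d_head))
  assert jnp.array_equal(joined, x)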