def Init(shape, rng):
  """Returns orthogonalized random normal values with the given `shape`.

  A float32 standard-normal matrix is QR-decomposed, the Q factor is
  sign-corrected, and the result is reshaped to `shape` and scaled.

  Args:
    shape: target shape. All axes except the last are flattened into rows;
      the last axis gives the number of columns. Shapes with fewer than two
      axes are left-padded with 1s for the computation only.
    rng: random key forwarded to `random.normal`.

  Returns:
    `stddev * q` where `q` has the given `shape`. `stddev` is not a parameter
    of this function; it is a free variable, presumably bound by an enclosing
    initializer factory (confirm at the definition site).
  """
  dims = list(shape)
  if len(dims) < 2:
    dims = [1] * (2 - len(dims)) + dims
  *leading, n_cols = dims
  n_rows = 1
  for size in leading:
    n_rows *= size
  # Decompose a tall matrix so that Q covers the whole flat shape; for a wide
  # target, draw the transposed shape and flip Q back afterwards.
  wide = n_rows < n_cols
  flat_shape = (n_cols, n_rows) if wide else (n_rows, n_cols)
  gaussian = random.normal(rng, flat_shape, dtype=np.float32)
  q, r = np.linalg.qr(gaussian)
  # Scale each column of Q by the sign of the matching diagonal entry of R so
  # that the factorization is unique ("make Q uniform").
  q = q * np.sign(np.diag(r))
  if wide:
    q = np.transpose(q)
  return stddev * np.reshape(q, shape)
def CombineHeadsPos(x, n_heads=1, **unused_kwargs):
  """Mix x = (x0, p0, ..., xH, pH) into (x0, ...., xH), p_combined.

  The positions are averaged as vectors.

  Args:
    x: input vector, concatenated (x0, p0, ..., xH, pH). It is reshaped as
      (n_batch * n_heads, seqlen, d_head). Within each head, the first
      `d_head - POS_VECTOR_SIZE` entries are content and the remaining
      `POS_VECTOR_SIZE` entries are that head's position vector.
    n_heads: number of heads.
    **unused_kwargs: ignored.

  Returns:
    A pair:
    - the head contents concatenated along the last axis, with shape
      (n_batch, seqlen, n_heads * (d_head - POS_VECTOR_SIZE));
    - the element-wise mean of the per-head position vectors, with shape
      (n_batch, seqlen, POS_VECTOR_SIZE).
  """
  seqlen, d_head = x.shape[1], x.shape[2]
  # (n_batch * n_heads, seqlen, d_head) -> (n_batch, seqlen, n_heads * d_head)
  by_head = np.transpose(
      np.reshape(x, (-1, n_heads, seqlen, d_head)), (0, 2, 1, 3))
  flat = np.reshape(by_head, (-1, seqlen, n_heads * d_head))
  head_size = int(d_head) - POS_VECTOR_SIZE
  stride = head_size + POS_VECTOR_SIZE
  starts = [h * stride for h in range(n_heads)]
  contents = [flat[:, :, s:s + head_size] for s in starts]
  positions = [flat[:, :, s + head_size:s + stride] for s in starts]
  combined_position = sum(positions) / float(len(positions))
  return np.concatenate(contents, axis=-1), combined_position
def forward(self, x, weights):
  """Merges per-head arrays into one vector per position and projects it.

  Args:
    x: array laid out as (n_batch * self._n_heads, seqlen, d_head); the heads
      are folded into the leading axis, as the reshape below requires.
    weights: matrix applied with `np.dot` to the merged array. Presumably
      (self._n_heads * d_head, d_out); confirm against the layer's init.

  Returns:
    `np.dot` of the merged (n_batch, seqlen, self._n_heads * d_head) array
    with `weights`.
  """
  seqlen, d_head = x.shape[1], x.shape[2]
  n_heads = self._n_heads
  split = np.reshape(x, (-1, n_heads, seqlen, d_head))
  # (n_batch, n_heads, seqlen, d_head) -> (n_batch, seqlen, n_heads, d_head)
  seq_major = np.transpose(split, (0, 2, 1, 3))
  merged = np.reshape(seq_major, (-1, seqlen, n_heads * d_head))
  return np.dot(merged, weights)
def forward(self, x, weights):
  """Projects `x` with `weights` and splits the result into per-head sequences.

  Args:
    x: array indexed as (n_batch, seqlen, ...); `np.dot(x, weights)` must
      produce a last axis of size self._n_heads * self._d_head.
    weights: projection matrix, presumably
      (d_model, self._n_heads * self._d_head); confirm against the layer's init.

  Returns:
    Array of shape (n_batch * self._n_heads, seqlen, self._d_head), with the
    heads folded into the leading axis.
  """
  n_batch, seqlen = x.shape[0], x.shape[1]
  projected = np.dot(x, weights)
  per_head = np.reshape(
      projected, (n_batch, seqlen, self._n_heads, self._d_head))
  # (n_batch, seqlen, n_heads, d_head) -> (n_batch, n_heads, seqlen, d_head)
  heads_major = np.transpose(per_head, (0, 2, 1, 3))
  return np.reshape(heads_major, (-1, seqlen, self._d_head))
def forward(self, inp, weights):
  """Reshape input to have heads dimension and concatenate positions there.

  Args:
    inp: sequence whose first element `x` has shape
      (n_batches, seqlen, self._n_heads * d_head). The remaining elements are
      rank-3 position arrays (n_batches, seqlen, pos_width). If
      `self._n_pos == 1`, only `inp[1]` is used and it is shared by every
      head. Otherwise each element of `inp[1:]` supplies one head's positions.
    weights: unused.

  Returns:
    Array of shape (n_batches * self._n_heads, seqlen, d_head + pos_width).
    The heads are folded into the leading axis, and each head's features are
    followed by its position vector.
  """
  x = inp[0]
  n_batches, seqlen = x.shape[0], x.shape[1]
  d_head = x.shape[-1] // self._n_heads
  res = np.reshape(x, (n_batches, seqlen, self._n_heads, d_head))
  res = np.transpose(res, (0, 2, 1, 3))  # (batch, heads, len, depth)
  if self._n_pos == 1:
    # Just one position given: tile it into each head. broadcast_to keeps the
    # position dtype and avoids allocating a zeros array just to broadcast.
    pos_shape = tuple(res.shape[:-1]) + (inp[1].shape[-1],)
    pos = np.broadcast_to(inp[1][:, None, :, :], pos_shape)
  else:
    # As many positions as heads; stack them along the heads axis.
    pos = np.concatenate([p[:, None, :, :] for p in inp[1:]], axis=1)
  res = np.concatenate([res, pos], axis=-1)
  # n_batch, n_heads, seqlen, depth -> n_batch*n_heads, seqlen, depth.
  # Use the real concatenated width rather than assuming every position
  # vector is exactly POS_VECTOR_SIZE wide.
  return np.reshape(res, (-1, seqlen, res.shape[-1]))
def _sample_rotation(self, shape, vecs, rng):
  """Samples a rotation matrix, either randomly or based on `vecs`.

  If `self._data_rotation` is falsy, returns standard-normal values of
  `shape`, cast to float32. Otherwise `shape` must have exactly three entries
  (n_dim, n_hashes, r_div_2) and `vecs` must be 2-D. Pairs of rows of `vecs`
  are drawn, and their differences `vec2 - vec1` are returned. Note that the
  data-based result is a matrix of difference vectors, not an orthogonalized
  rotation.

  Args:
    shape: requested output shape. Unpacked as (n_dim, n_hashes, r_div_2)
      when sampling from data.
    vecs: 2-D array whose rows are the candidate vectors; read only when
      `self._data_rotation` is set. Presumably each row has length n_dim,
      since the final transpose puts the row length first (TODO: confirm
      with callers).
    rng: random key consumed by `jax.random` and `backend.random.split`.

  Returns:
    Either the float32 normal sample of `shape`, or an array of shape
    (row_length, n_hashes, r_div_2) holding `vec2 - vec1` differences.
  """
  if not self._data_rotation:
    return jax.random.normal(rng, shape).astype('float32')
  assert len(shape) == 3
  unused_n_dim, n_hashes, r_div_2 = shape
  assert len(vecs.shape) == 2
  n_vecs = vecs.shape[0]
  rng1, rng2 = backend.random.split(rng, num=2)
  # We need to sample 2 * n_hashes * r_div_2 vectors from `vecs` at random.
  num_needed = 2 * n_hashes * r_div_2
  if n_vecs < num_needed:
    # Too few rows to sample without replacement: draw indices independently
    # (with replacement).
    # shape = (n_hashes, r_div_2)
    random_idxs_1 = jax.random.randint(rng1, (n_hashes, r_div_2), 0, n_vecs)
    random_idxs_2 = jax.random.randint(rng2, (n_hashes, r_div_2), 0, n_vecs)
  else:
    # Sample without replacement.
    # NOTE(review): jax.random.shuffle is deprecated/removed in newer JAX in
    # favor of jax.random.permutation; confirm against the pinned JAX version.
    shuffled_indices = jax.random.shuffle(rng1, np.arange(n_vecs))
    random_idxs = np.reshape(shuffled_indices[:num_needed], (2, n_hashes, r_div_2))
    random_idxs_1 = random_idxs[0]
    random_idxs_2 = random_idxs[1]
  if self._data_rotation_farthest:
    # For each vec1, draw `self._data_rotation_farthest_num` candidate vec2s
    # and keep the candidate whose dot product with vec1 has the smallest
    # magnitude (argmax of -|dot|). random_idxs_2 is not used in this branch.
    # shape = (n_hashes * r_div_2, )
    random_idxs_1 = np.reshape(random_idxs_1, (-1, ))
    random_vecs_1 = vecs[random_idxs_1]
    # Sample candidates for vec2s.
    # NOTE(review): `rng` was already split into rng1/rng2 above. If
    # backend.random.split behaves like jax.random.split with a default of
    # num=2, subrng equals rng2. This is harmless here only because
    # random_idxs_2 is discarded in this branch; consider splitting rng2
    # instead.
    rng, subrng = backend.random.split(rng)
    # shape = (self._data_rotation_farthest_num, n_hashes * r_div_2)
    candidate_idxs_2 = jax.random.randint(
        subrng, (self._data_rotation_farthest_num, n_hashes * r_div_2), 0, n_vecs)
    candidate_vecs_2 = vecs[candidate_idxs_2]
    # shape = candidate_idxs_2.shape
    distances = -np.abs(
        np.einsum('hd,chd->ch', random_vecs_1, candidate_vecs_2))
    # shape = (n_hashes * r_div_2,)
    farthest_idxs = np.argmax(distances, axis=0)
    # Pick, for each column, the chosen candidate row.
    # shape = (n_hashes * r_div_2, row_length)
    random_vecs_2 = candidate_vecs_2[farthest_idxs, np.arange(n_hashes * r_div_2)]
    # reshape to (n_hashes, r_div_2, n_dim)
    random_vecs_1 = np.reshape(random_vecs_1, (n_hashes, r_div_2, -1))
    random_vecs_2 = np.reshape(random_vecs_2, (n_hashes, r_div_2, -1))
  else:
    # shape = (n_hashes, r_div_2, n_dim)
    random_vecs_1 = vecs[random_idxs_1]
    random_vecs_2 = vecs[random_idxs_2]
  # shape = (n_dim, n_hashes, r_div_2)
  return np.transpose(random_vecs_2 - random_vecs_1, axes=[2, 0, 1])
def JoinHeads(x):  # pylint: disable=invalid-name
  """Moves the heads axis next to the features and merges the two.

  `x` must be rank 4 (the transpose below fixes that). It is presumably laid
  out as (nbatch, n_heads, length, d_head). `nbatch`, `n_heads` and `d_head`
  are free variables resolved from the enclosing scope.

  Returns:
    Array of shape (nbatch, length, n_heads * d_head), where `length` is
    inferred by the reshape.
  """
  heads_inner = np.transpose(x, (0, 2, 1, 3))
  return np.reshape(heads_inner, (nbatch, -1, n_heads * d_head))
def SplitHeads(x):
  """Splits the feature axis of `x` into `n_heads` chunks of width `d_head`.

  `nbatch`, `n_heads` and `d_head` are free variables resolved from the
  enclosing scope.

  Returns:
    Array of shape (nbatch, n_heads, length, d_head), where `length` is
    inferred by the reshape.
  """
  grouped = np.reshape(x, (nbatch, -1, n_heads, d_head))
  # (nbatch, length, n_heads, d_head) -> (nbatch, n_heads, length, d_head)
  return np.transpose(grouped, (0, 2, 1, 3))