def PerformPositionOperations(pos, positions=None):
  """Gets pos and returns (q1, ..., q5)."""
  succ_keys = positions[:-1, :]
  succ_values = positions[1:, :]
  subtract_1_keys = positions[1:, :]
  subtract_1_values = positions[:-1, :]
  l = int(positions.shape[0]) // 2
  add_keys = np.array([np.concatenate([positions[i, :], positions[j, :]])
                       for i in range(l) for j in range(l)])
  add_values = np.array([positions[i + j, :]
                         for i in range(l) for j in range(l)])
  # TODO(lukaszkaiser): try this below: "for j in range(i) for i in range(2*l)"
  sub_keys = np.array([np.concatenate([positions[i, :], positions[j, :]])
                       for j in range(l) for i in range(l)])
  sub_values = np.array([positions[max(i - j, 0), :]
                         for j in range(l) for i in range(l)])
  query_types = [
      QueryPositionKV(),
      QueryPositionKV(keys=succ_keys, values=succ_values),
      QueryPositionKV(keys=subtract_1_keys, values=subtract_1_values),
      QueryPositionKV(keys=add_keys, values=add_values, binary=True),
      QueryPositionKV(keys=sub_keys, values=sub_values, binary=True)]
  return [qt @ pos for qt in query_types]  # pylint: disable=syntax-error
def forward_with_state(self, x, weights, state, rng):
  batch_size, length = x.shape[0], x.shape[1]
  max_pos = min(self._bases)**self._n_digits
  rng1, rng2, rng3 = math.random.split(rng, 3)
  assert length < max_pos, 'length (%d) >= max_pos (%d)' % (length, max_pos)
  positions = jnp.arange(0, length)[None, :]
  if self._mode == 'train':
    # In 1% of training cases still start from 0 to be exactly as in eval.
    start_from_nonzero = jax.random.randint(
        rng1, (batch_size,), 0, self._start_from_zero_one_in)
    start_from_nonzero = jnp.minimum(1, start_from_nonzero)
    random_start = jax.random.randint(rng2, (batch_size,), 0, max_pos - length)
    random_start *= start_from_nonzero
    positions += random_start[:, None]
  res = []
  for bn, base in enumerate(self._bases):
    pos_embeddings = []
    cur_positions = positions
    for i in range(self._n_digits):
      cur_indices = jnp.mod(cur_positions, base)
      cur_positions = cur_positions // base
      s = weights[bn][i]
      pos_embeddings.append(cur_indices.astype(jnp.float32)[:, :, None] * s)
    embeddings = jnp.concatenate(pos_embeddings, axis=-1)
    if self._mode == 'train':
      base_dropout = jax.random.randint(
          rng3, (batch_size,), 0, self._base_dropout_one_in)
      base_dropout = jnp.minimum(1, base_dropout).astype(jnp.float32)
      embeddings *= base_dropout[:, None, None]
    res.append(embeddings)
  res = sum(res) + jnp.zeros_like(x)
  return jnp.concatenate([x, res], axis=-1), state
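# A minimal standalone sketch (not part of the layer above) of the digit
# decomposition used in forward_with_state: each position is written in a
# given base, and every digit scales a per-digit vector. The names and sizes
# (`bases`, `n_digits`, `d_feature`) and the random vectors standing in for
# the layer's weights are illustrative assumptions.
import numpy as onp


def digit_position_features(length, bases=(11, 13), n_digits=3, d_feature=4):
  rand = onp.random.RandomState(0)
  positions = onp.arange(length)[None, :]                 # (1, length)
  res = []
  for base in bases:
    cur_positions = positions
    pos_embeddings = []
    for _ in range(n_digits):
      cur_indices = onp.mod(cur_positions, base)          # digit in this base
      cur_positions = cur_positions // base
      s = rand.normal(size=(d_feature,))                  # stands in for weights[bn][i]
      pos_embeddings.append(cur_indices.astype(onp.float32)[:, :, None] * s)
    res.append(onp.concatenate(pos_embeddings, axis=-1))
  return sum(res)                                         # (1, length, n_digits * d_feature)


print(digit_position_features(5).shape)                   # (1, 5, 12)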
def Interleave(inputs, **unused_kwargs):
  """Interleaves and flattens two serialized sequences.

  The first sequence can be longer by 1 than the second one. This is so we can
  interleave sequences of observations and actions, when there's 1 extra
  observation at the end.

  For serialized sequences [[x_1_1, ..., x_1_R1], ..., [x_L1_1, ..., x_L1_R1]]
  and [[y_1_1, ..., y_1_R2], ..., [y_L2_1, ..., y_L2_R2]], where L1 = L2 + 1,
  the result is [x_1_1, ..., x_1_R1, y_1_1, ..., y_1_R2, ..., x_L2_1, ...,
  x_L2_R1, y_L2_1, ..., y_L2_R2, x_L1_1, ..., x_L1_R1] (batch dimension omitted
  for clarity).

  Args:
    inputs: Pair of sequences of shapes (B, L1, R1) and (B, L2, R2), where B
      is batch size, L* is the length of the sequence and R* is the
      representation length of each element in the sequence.

  Returns:
    Interleaved sequence of shape (B, L1 * R1 + L2 * R2).
  """
  (x, y) = inputs
  (batch_size, _, _) = x.shape
  (_, length, _) = y.shape
  assert x.shape[1] in (length, length + 1)

  reprs = np.concatenate((x[:, :length], y), axis=2)
  reprs = np.reshape(reprs, (batch_size, -1))
  remainder = np.reshape(x[:, length:], (batch_size, -1))
  return np.concatenate((reprs, remainder), axis=1)
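# Hedged usage sketch for Interleave above, treating it as the plain function
# defined here and assuming `np` in this module is numpy-compatible. Two toy
# "serialized" sequences, the first one step longer, get flattened into a
# single per-example stream; shapes are made up for illustration.
import numpy as onp

obs = onp.arange(2 * 3 * 2).reshape(2, 3, 2)        # (B=2, L1=3, R1=2)
act = 100 + onp.arange(2 * 2 * 1).reshape(2, 2, 1)  # (B=2, L2=2, R2=1)
out = Interleave((obs, act))
print(out.shape)   # (2, 3 * 2 + 2 * 1) == (2, 8)
print(out[0])      # -> 0, 1, 100, 2, 3, 101, 4, 5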
def interleave(x, y):
  (batch_size, _, _) = x.shape
  (_, length, _) = y.shape
  assert x.shape[1] in (length, length + 1)

  reprs = jnp.concatenate((x[:, :length], y), axis=2)
  reprs = jnp.reshape(reprs, (batch_size, -1))
  remainder = jnp.reshape(x[:, length:], (batch_size, -1))
  return jnp.concatenate((reprs, remainder), axis=1)
def reverse(self, output, weights=(), state=(), new_state=(), **kwargs):
  del weights, kwargs
  x1_split = []
  x2_split = []
  for y in output:
    y1, y2 = np.split(y, 2, -1)
    x1_split.append(y1)
    x2_split.append(y2)

  x1 = np.concatenate(x1_split, self._axis)
  x2 = np.concatenate(x2_split, self._axis)

  return (x1, x2)
def forward(self, inputs, weights):
  x, gru_state = inputs

  # Dense layer on the concatenation of x and h.
  w1, b1, w2, b2 = weights
  y = jnp.dot(jnp.concatenate([x, gru_state], axis=-1), w1) + b1

  # Update and reset gates.
  u, r = jnp.split(math.sigmoid(y), 2, axis=-1)

  # Candidate.
  c = jnp.dot(jnp.concatenate([x, r * gru_state], axis=-1), w2) + b2

  new_gru_state = u * gru_state + (1 - u) * jnp.tanh(c)
  return new_gru_state, new_gru_state
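# Minimal numpy sketch of a single GRU step with the weight layout above:
# w1/b1 produce the update and reset gates, w2/b2 produce the candidate.
# Plain numpy and a local sigmoid stand in for jnp and math.sigmoid; the
# dimensions (d_in, d_hidden, batch) are illustrative assumptions.
import numpy as onp


def _sigmoid(z):
  return 1.0 / (1.0 + onp.exp(-z))


d_in, d_hidden, batch = 3, 4, 2
rand = onp.random.RandomState(0)
w1 = rand.normal(size=(d_in + d_hidden, 2 * d_hidden))
b1 = onp.zeros(2 * d_hidden)
w2 = rand.normal(size=(d_in + d_hidden, d_hidden))
b2 = onp.zeros(d_hidden)

x = rand.normal(size=(batch, d_in))
gru_state = onp.zeros((batch, d_hidden))

y = onp.dot(onp.concatenate([x, gru_state], axis=-1), w1) + b1
u, r = onp.split(_sigmoid(y), 2, axis=-1)                     # update / reset gates
c = onp.dot(onp.concatenate([x, r * gru_state], axis=-1), w2) + b2
new_gru_state = u * gru_state + (1 - u) * onp.tanh(c)
print(new_gru_state.shape)                                    # (2, 4)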
def forward_with_state(self, inputs, weights=layer_base.EMPTY_WEIGHTS,
                       state=layer_base.EMPTY_STATE, rng=None, **kwargs):
  depth = inputs.shape[-1]

  if self._mode == 'predict':
    emb = self._get_embeddings(t=state)
    emb = emb[:, np.newaxis, :]
    state = state + 1
  else:
    input_len = inputs.shape[-2]
    emb = self._get_embeddings(t=np.arange(input_len, dtype=np.int32))
    # Leave batch axis as 1 for broadcasting:
    emb = emb[np.newaxis, :, :]
    emb = np.broadcast_to(emb, inputs.shape[:-1] + (3,))

  # Replace the last num_features channels of input.
  inputs = np.concatenate([inputs[..., :-self.num_features], emb], -1)
  if inputs.shape[-1] > depth:
    logging.warning('dropping feature(s): %d down to %d', inputs.shape[-1],
                    depth)
    inputs = inputs[..., -depth:]

  assert inputs.shape[-1] == depth, inputs.shape
  return inputs, state
def test_weights_state(self):
  layer = base.Fn(
      '2in2out', lambda x, y: (x + y, jnp.concatenate([x, y], axis=0)),
      n_out=2)
  weights, state = layer.new_weights_and_state(None)
  self.assertEmpty(weights)
  self.assertEmpty(state)
def forward_with_state(self, inputs, weights=base.EMPTY_WEIGHTS,
                       state=base.EMPTY_STATE, rng=None, **kwargs):
  embs = []
  for ax_emb in weights:
    ax_emb = np.broadcast_to(
        ax_emb, (inputs.shape[0],) + self._shape + (ax_emb.shape[-1],))
    embs.append(ax_emb)
  emb = np.concatenate(embs, -1)

  if self._mode == 'predict':
    assert self._dropout == 0.0
    emb = np.reshape(emb, (inputs.shape[0], -1, emb.shape[-1]))
    return inputs + emb[:, state, :][:, None, :], state + 1
  elif self._dropout == 0:
    return inputs + np.reshape(emb, inputs.shape), state
  else:
    noise_shape = list(emb.shape)
    for dim in self._dropout_broadcast_dims:
      noise_shape[dim] = 1
    keep_prob = 1.0 - self._dropout
    if math.backend_name() == 'jax':
      keep_prob = jax.lax.tie_in(
          inputs, np.full((), keep_prob, dtype=inputs.dtype))
    keep = math.random.bernoulli(rng, keep_prob, tuple(noise_shape))
    multiplier = keep.astype(inputs.dtype) / keep_prob
    return inputs + np.reshape(emb * multiplier, inputs.shape), state
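# Hedged standalone sketch of the axial-embedding broadcast above: one
# embedding per axis of a factored position grid, broadcast to the full grid
# and concatenated on the feature axis (the layer then adds the result to its
# inputs). The grid shape and per-axis depths are illustrative assumptions and
# random arrays stand in for the layer's weights.
import numpy as onp

batch, shape, d_axis = 2, (4, 6), (3, 5)            # 4 * 6 = 24 positions
axial_weights = [onp.random.randn(shape[0], 1, d_axis[0]),
                 onp.random.randn(1, shape[1], d_axis[1])]
embs = []
for ax_emb in axial_weights:
  embs.append(onp.broadcast_to(ax_emb, (batch,) + shape + (ax_emb.shape[-1],)))
emb = onp.concatenate(embs, -1)
print(emb.shape)                                     # (2, 4, 6, 8)
flat = onp.reshape(emb, (batch, -1, emb.shape[-1]))  # one row per position
print(flat.shape)                                    # (2, 24, 8)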
def CombineHeadsPos(x, n_heads=1, **unused_kwargs):
  """Mix x = (x0, p0, ..., xH, pH) into (x0, ...., xH), p_combined.

  The positions are averaged as vectors.

  Args:
    x: input vector, concatenated (x0, p0, ..., xH, pH).
    n_heads: number of heads.

  Returns:
    the vector with combined xs and one with combined positions.
  """
  seqlen = x.shape[1]
  d_head = x.shape[2]
  x = np.reshape(x, (-1, n_heads, seqlen, d_head))
  x = np.transpose(x, (0, 2, 1, 3))  # -> n_batch, seqlen, n_heads, d_head
  x = np.reshape(x, (-1, seqlen, n_heads * d_head))
  head_size = int(d_head) - POS_VECTOR_SIZE
  res, positions, idx = [], [], 0
  for _ in range(n_heads):
    res.append(x[:, :, idx:idx + head_size])
    idx += head_size
    positions.append(x[:, :, idx:idx + POS_VECTOR_SIZE])
    idx += POS_VECTOR_SIZE
  combined_position = sum(positions) / float(len(positions))
  return np.concatenate(res, axis=-1), combined_position
def forward(self, inputs, weights):
  state = self.state
  depth = inputs.shape[-1]

  if self._mode == 'predict':
    emb = self._get_embeddings(t=state)
    emb = emb[:, jnp.newaxis, :]
    state = state + 1
  else:
    input_len = inputs.shape[-2]
    emb = self._get_embeddings(t=jnp.arange(input_len, dtype=jnp.int32))
    # Leave batch axis as 1 for broadcasting:
    emb = emb[jnp.newaxis, :, :]
    emb = jnp.broadcast_to(emb, inputs.shape[:-1] + (3,))

  # Replace the last num_features channels of input.
  inputs = jnp.concatenate([inputs[..., :-self.num_features], emb], -1)
  if inputs.shape[-1] > depth:
    logging.warning('dropping feature(s): %d down to %d', inputs.shape[-1],
                    depth)
    inputs = inputs[..., -depth:]

  assert inputs.shape[-1] == depth, inputs.shape
  self.state = state
  return inputs
def reverse(self, output, weights=(), state=(), new_state=(), **kwargs):
  del weights, kwargs
  if not isinstance(output, (list, tuple)):
    output = [output]

  x1_split = []
  x2_split = []
  for y in output:
    y1, y2 = np.split(y, 2, -1)
    x1_split.append(y1)
    x2_split.append(y2)

  x1 = np.concatenate(x1_split, self._axis)
  x2 = np.concatenate(x2_split, self._axis)

  return (x1, x2)
def forward(self, inputs, weights):
  x, lstm_state = inputs

  # LSTM state consists of c and h.
  c, h = jnp.split(lstm_state, 2, axis=-1)

  # Dense layer on the concatenation of x and h.
  w, b = weights
  y = jnp.dot(jnp.concatenate([x, h], axis=-1), w) + b

  # i = input_gate, j = new_input, f = forget_gate, o = output_gate
  i, j, f, o = jnp.split(y, 4, axis=-1)

  new_c = c * math.sigmoid(f) + math.sigmoid(i) * jnp.tanh(j)
  new_h = jnp.tanh(new_c) * math.sigmoid(o)
  return new_h, jnp.concatenate([new_c, new_h], axis=-1)
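# Minimal numpy sketch of a single LSTM step with the weight layout above: one
# dense layer maps [x, h] to the four gates, and the carried state packs c and
# h together on the last axis. Plain numpy and a local sigmoid stand in for
# jnp and math.sigmoid; sizes are illustrative assumptions.
import numpy as onp


def _sigmoid(z):
  return 1.0 / (1.0 + onp.exp(-z))


d_in, d_hidden, batch = 3, 4, 2
rand = onp.random.RandomState(0)
w = rand.normal(size=(d_in + d_hidden, 4 * d_hidden))
b = onp.zeros(4 * d_hidden)

x = rand.normal(size=(batch, d_in))
lstm_state = onp.zeros((batch, 2 * d_hidden))        # [c, h] concatenated

c, h = onp.split(lstm_state, 2, axis=-1)
y = onp.dot(onp.concatenate([x, h], axis=-1), w) + b
i, j, f, o = onp.split(y, 4, axis=-1)
new_c = c * _sigmoid(f) + _sigmoid(i) * onp.tanh(j)
new_h = onp.tanh(new_c) * _sigmoid(o)
print(new_h.shape, onp.concatenate([new_c, new_h], axis=-1).shape)  # (2, 4) (2, 8)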
def hash_vectors(self, vecs, rng):
  # See https://arxiv.org/pdf/1509.02897.pdf
  # We sample a different random rotation for each round of hashing to
  # decrease the probability of hash misses.
  if isinstance(self.n_buckets, int):
    assert self.n_buckets % 2 == 0
    rot_size = self.n_buckets
    n_buckets = self.n_buckets
  else:
    # Factorize the hash if self.n_buckets is a list or tuple
    rot_size, n_buckets = 0, 1
    for factor in self.n_buckets:
      assert factor % 2 == 0
      rot_size += factor
      n_buckets *= factor

  rotations_shape = (vecs.shape[-1], self.n_hashes, rot_size // 2)

  rng = jax.lax.stop_gradient(jax.lax.tie_in(vecs, rng))
  random_rotations = jax.random.normal(rng, rotations_shape).astype('float32')
  rotated_vecs = np.einsum('tf,fhb->htb', vecs, random_rotations)

  if isinstance(self.n_buckets, int) or len(self.n_buckets) == 1:
    rotated_vecs = np.concatenate([rotated_vecs, -rotated_vecs], axis=-1)
    buckets = np.argmax(rotated_vecs, axis=-1)
  else:
    # Get the buckets for them and combine.
    buckets, cur_sum, cur_product = None, 0, 1
    for factor in self.n_buckets:
      rv = rotated_vecs[..., cur_sum:cur_sum + (factor // 2)]
      cur_sum += factor // 2
      rv = np.concatenate([rv, -rv], axis=-1)
      if buckets is None:
        buckets = np.argmax(rv, axis=-1)
      else:
        buckets += cur_product * np.argmax(rv, axis=-1)
      cur_product *= factor

  # buckets is now (self.n_hashes, seqlen). Next we add offsets so that
  # bucket numbers from different hashing rounds don't overlap.
  offsets = jax.lax.tie_in(buckets, np.arange(self.n_hashes))
  offsets = np.reshape(offsets * n_buckets, (-1, 1))
  buckets = np.reshape(buckets + offsets, (-1,))
  return buckets
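# Hedged standalone sketch of the LSH bucketing step in hash_vectors, for the
# simple case of an integer n_buckets: project onto random rotations, append
# the negated projections, and take the argmax as the bucket id. Plain numpy
# stands in for the jax calls (no tie_in / stop_gradient), and the sizes are
# illustrative assumptions.
import numpy as onp

seqlen, d, n_buckets, n_hashes = 8, 16, 4, 2
rand = onp.random.RandomState(0)
vecs = rand.normal(size=(seqlen, d)).astype('float32')

rotations = rand.normal(size=(d, n_hashes, n_buckets // 2)).astype('float32')
rotated = onp.einsum('tf,fhb->htb', vecs, rotations)        # (n_hashes, seqlen, n_buckets/2)
rotated = onp.concatenate([rotated, -rotated], axis=-1)     # (n_hashes, seqlen, n_buckets)
buckets = onp.argmax(rotated, axis=-1)                      # bucket id per hash round

# Offset each round so bucket ids from different rounds do not collide.
offsets = onp.arange(n_hashes).reshape(-1, 1) * n_buckets
print(onp.reshape(buckets + offsets, (-1,)).shape)          # (n_hashes * seqlen,) == (16,)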
def forward(self, inp, weights):
  """Reshape input to have heads dimension and concatenate positions there."""
  x = inp[0]
  n_batches, seqlen = x.shape[0], x.shape[1]
  d_head = x.shape[-1] // self._n_heads
  res = np.reshape(x, (n_batches, seqlen, self._n_heads, d_head))
  res = np.transpose(res, (0, 2, 1, 3))  # (batch, heads, len, depth)
  if self._n_pos == 1:  # Just one position given, tile into each head.
    pos_shape = list(res.shape)[:-1] + [inp[1].shape[-1]]
    pos = inp[1][:, None, :, :] + np.zeros(pos_shape)  # Add 0 to broadcast.
  else:  # As many positions as heads, concatenate them in.
    pos = [p[:, None, :, :] for p in inp[1:]]
    pos = np.concatenate(pos, axis=1)
  res = np.concatenate([res, pos], axis=-1)
  # n_batch, n_heads, seqlen, d_head -> n_batch*n_heads, seqlen, d_head
  res = np.reshape(res, (-1, seqlen, d_head + POS_VECTOR_SIZE))
  return res
def forward(self, inputs, weights):
  del weights
  x1, x2 = inputs
  x1_split = np.split(x1, self._n_sections, self._axis)
  x2_split = np.split(x2, self._n_sections, self._axis)
  res = [np.concatenate(ys, -1) for ys in zip(x1_split, x2_split)]
  return tuple(res)
def _UpdateRow(x):
  # row_e - (L1, H), row_d - (L2, H), row_mask_e - (L1,)
  row_e, row_d, row_mask_e = x
  # final_row - (L1+L2, H)
  final_row = jnp.concatenate([row_e, jnp.zeros_like(row_d)], axis=0)
  # Find the last real token/vector of the encoder.
  e_idx = jnp.sum(row_mask_e, dtype=jnp.int32)
  # Starting after that index, update with the decoder row.
  return jax.lax.dynamic_update_slice(final_row, row_d, (e_idx, 0))
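# Hedged toy run of _UpdateRow above (requires jax): the decoder row is
# written into the zero-padded encoder row right after its last real token.
# The tiny shapes and values are made up for illustration.
import jax.numpy as jnp

row_e = jnp.array([[1., 1.], [2., 2.], [0., 0.]])   # (L1=3, H=2); last slot is padding
row_d = jnp.array([[7., 7.], [8., 8.]])             # (L2=2, H=2)
row_mask_e = jnp.array([1, 1, 0])                   # marks the real encoder tokens
print(_UpdateRow((row_e, row_d, row_mask_e)))
# [[1. 1.] [2. 2.] [7. 7.] [8. 8.] [0. 0.]]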
def test_fn_layer_weights_state(self):
  layer = Fn('2in2out',
             lambda x, y: (x + y, np.concatenate([x, y], axis=0)),
             n_out=2)
  input_signature = None
  weights, state = layer.new_weights_and_state(input_signature)
  self.assertIsNotNone(weights)
  self.assertIsNotNone(state)
  self.assertEmpty(weights)
  self.assertEmpty(state)
def forward_with_state(self, inputs, weights=base.EMPTY_WEIGHTS,
                       state=base.EMPTY_STATE, rng=None, **kwargs):
  embs = []
  for ax_emb in weights:
    ax_emb = np.broadcast_to(
        ax_emb, (inputs.shape[0],) + self._shape + (ax_emb.shape[-1],))
    embs.append(ax_emb)

  if self._mode == 'predict':
    assert self._dropout == 0.0
    emb = np.concatenate(embs, -1)
    emb = np.reshape(emb, (inputs.shape[0], -1, emb.shape[-1]))
    emb = jax.lax.dynamic_slice_in_dim(emb, state, inputs.shape[1], axis=1)
    return inputs + emb, state + inputs.shape[1]
  elif self._dropout == 0:
    # TODO(kitaev): concat-then-reshape (as is the case with dropout enabled)
    # leads to memory blow-up on TPU.
    # emb = np.concatenate(embs, -1)
    # return inputs + np.reshape(emb, inputs.shape), state
    return inputs + np.concatenate(
        [np.reshape(emb, inputs.shape[:-1] + (emb.shape[-1],))
         for emb in embs], -1), state
  else:
    emb = np.concatenate(embs, -1)
    noise_shape = list(emb.shape)
    for dim in self._dropout_broadcast_dims:
      noise_shape[dim] = 1
    keep_prob = 1.0 - self._dropout
    if math.backend_name() == 'jax':
      keep_prob = jax.lax.tie_in(
          inputs, np.full((), keep_prob, dtype=inputs.dtype))
    keep = math.random.bernoulli(rng, keep_prob, tuple(noise_shape))
    multiplier = keep.astype(inputs.dtype) / keep_prob
    return inputs + np.reshape(emb * multiplier, inputs.shape), state
def QueryPositionKV(x, keys=None, values=None, binary=False, **unused_kwargs):
  """Query a table with a position vector."""
  if keys is None:
    return x
  k = np.array(keys)
  v = np.array(values)
  q = x
  if binary:
    q = np.concatenate([x, x], axis=-1)
  return tl.DotProductAttention(q, k, v, None, 0.0, None, None)
def test_fn_layer_example(self):
  layer = base.Fn(lambda x, y: (x + y, np.concatenate([x, y], axis=0)))
  input_signature = (ShapeDtype((2, 7)), ShapeDtype((2, 7)))
  expected_shape = ((2, 7), (4, 7))
  output_shape = base.check_shape_agreement(layer, input_signature)
  self.assertEqual(output_shape, expected_shape)
  inp = (np.array([2]), np.array([3]))
  x, xs = layer(inp)
  self.assertEqual(int(x), 5)
  self.assertEqual([int(y) for y in xs], [2, 3])
def test_fn_layer_difficult_n_out(self):
  with self.assertRaisesRegex(ValueError, 'n_out'):
    # Determining the output of this layer is hard with dummies.
    base.Fn(lambda x: np.concatenate([x, x], axis=4))

  # Check that this layer works when n_out is set.
  layer = base.Fn(lambda x: np.concatenate([x, x], axis=4), n_out=1)
  input_signature = ShapeDtype((2, 1, 2, 2, 3))
  expected_shape = (2, 1, 2, 2, 6)
  output_shape = base.check_shape_agreement(layer, input_signature)
  self.assertEqual(output_shape, expected_shape)
def _get_embeddings(self, lo: int, hi: int, depth, rng=None):
  """Get embeddings float[length, depth].

  Args:
    lo: where to start sampling
    hi: where to stop sampling
    depth: embedding depth
    rng: rng for random phase

  Returns:
    embeddings: float[length, depth]
  """
  noise = self._get_noise(lo, hi, (depth + 1) // 2)
  # Make the stddev around 1 after 1/drift.
  noise = noise * self._drift**.5

  t, c = onp.mgrid[lo:hi, :depth]
  # Make even channels cos, odd channels sin:
  c_div_2, c_mod_2 = divmod(c, 2)
  # Off-by-one correction for odd depth:
  drift = self._drift
  if depth > 2:
    drift = drift**(((depth + 1) // 2) / (depth // 2))
  # Spend roughly half the frequencies on noise:
  freq = np.geomspace(.5, .5 * drift**2, num=(depth + 1) // 2)[c_div_2]
  cycles = c_mod_2 / 4 + freq * t + noise[:, c_div_2[0, :]] / 4
  assert cycles.shape == (hi - lo, depth), cycles.shape

  # Get random phases:
  if self._affine:
    assert rng is not None
    cycles = cycles + trax.math.random.uniform(
        rng, (1, depth), minval=0, maxval=1)

  # Convert from cycles to radians:
  embeddings = np.cos(np.pi * 2 * cycles)

  # Set the last channels to the time bin features:
  if self._time_bin_length is not None:
    inter_bin_idx, intra_bin_idx = divmod(t[:, -1:], self._time_bin_length)
    bin_parity = inter_bin_idx % 2
    bin_fraction = intra_bin_idx / self._time_bin_length
    embeddings = np.concatenate([
        embeddings[:, :-3],
        1 / (1 + inter_bin_idx),
        bin_fraction,
        bin_parity.astype(np.float32),
    ], -1)

  assert embeddings.shape == (hi - lo, depth), embeddings.shape
  return embeddings
def look_adjacent(x, n_chunks_before, n_chunks_after):
  """Used to implement attention between consecutive chunks.

  Args:
    x: array of shape [n_chunks, chunk_len, ...]
    n_chunks_before: Number of previous chunks to attend to.
    n_chunks_after: Number of subsequent chunks to attend to.

  Returns:
    array of shape [n_chunks, N * chunk_len, ...], where
    N = (1 + n_chunks_before + n_chunks_after).
  """
  if n_chunks_before == 0 and n_chunks_after == 0:
    return x

  slices = []
  for i in range(-n_chunks_before, n_chunks_after + 1):
    if i == 0:
      slices.append(x)
    else:
      slices.append(np.concatenate([x[i:, ...], x[:i, ...]], axis=0))
  return np.concatenate(slices, axis=1)
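# Hedged toy run of look_adjacent above, treating it as the plain function
# defined here and assuming `np` is numpy-compatible. With 1 chunk before and
# 0 after, each chunk gets its predecessor's rows prepended (the first chunk
# wraps around to the last one).
import numpy as onp

chunks = onp.arange(6).reshape(3, 2, 1)   # 3 chunks, chunk_len 2, depth 1
out = look_adjacent(chunks, n_chunks_before=1, n_chunks_after=0)
print(out.shape)        # (3, 4, 1): each chunk now spans 2 * chunk_len rows
print(out[..., 0])      # [[4 5 0 1], [0 1 2 3], [2 3 4 5]]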
def f(x):  # pylint: disable=invalid-name
  # x : [batch, 1, length, depth]
  x = np.pad(x, [(0, 0), (0, 0), (1, 1), (0, 0)],
             mode='constant', constant_values=0.0)
  depth = x.shape[-1] // 3
  assert 3 * depth == x.shape[-1], ('Depth must be divisible by 3', depth,
                                    x.shape)
  xs = [x[:, :, :-2, :depth],
        x[:, :, 1:-1, depth:2 * depth],
        x[:, :, 2:, 2 * depth:3 * depth]]
  return np.concatenate(xs, axis=3)
def DiagonalGate(x):
  """Split channels in 3 parts. Shifts 1st and 3rd sections to left/right."""
  # x : [batch, 1, length, depth]
  x = np.pad(x, [(0, 0), (0, 0), (1, 1), (0, 0)],
             mode='constant', constant_values=0.0)
  depth = x.shape[-1] // 3
  assert 3 * depth == x.shape[-1], ('Depth must be divisible by 3', depth,
                                    x.shape)
  xs = [x[:, :, :-2, :depth],
        x[:, :, 1:-1, depth:2 * depth],
        x[:, :, 2:, 2 * depth:3 * depth]]
  return np.concatenate(xs, axis=3)
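# Hedged toy check of DiagonalGate above with depth 3 (one channel per
# section), treating it as the plain function defined here and assuming `np`
# is numpy-compatible: output position t takes its first third from position
# t-1, its middle third from t, and its last third from t+1 (zeros at edges).
import numpy as onp

base_x = onp.arange(4, dtype=onp.float32).reshape(1, 1, 4, 1)     # length 4
x_in = onp.concatenate([base_x, 10 * base_x, 100 * base_x], axis=-1)  # depth 3
print(DiagonalGate(x_in)[0, 0])
# [[  0.   0. 100.]
#  [  0.  10. 200.]
#  [  1.  20. 300.]
#  [  2.  30.   0.]]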
def Deinterleave(inputs, x_size, y_size, **unused_kwargs):
  """Inverse of Interleave."""
  reprs = inputs
  (batch_size, length) = reprs.shape[:2]
  shape_suffix = reprs.shape[2:]
  remainder_length = length % (x_size + y_size)
  # Only split off a remainder when there is one; slicing with :-0 would
  # otherwise drop the whole sequence (see deinterleave below).
  if remainder_length > 0:
    remainder = reprs[:, None, -remainder_length:]
    reprs = reprs[:, :-remainder_length]
  reprs = np.reshape(reprs, (batch_size, -1, x_size + y_size) + shape_suffix)
  x_reprs = reprs[:, :, :x_size]
  y_reprs = reprs[:, :, x_size:]
  if remainder_length > 0:
    x_reprs = np.concatenate((x_reprs, remainder), axis=1)
  return (x_reprs, y_reprs)
def F(x):
  # TODO(afrozm): What to do in this case?
  if mode == 'predict':
    raise ValueError('MaskOfRightShiftedArray not implemented for predict.')

  mask = x != 0

  if n_shifts == 0:
    return mask

  # Need to set (B, n_shifts, ...) section to True.
  trues_shape = (x.shape[0], n_shifts) + mask.shape[2:]
  trues = jnp.full(trues_shape, True)
  return jnp.concatenate([trues, mask[:, n_shifts:, ...]], axis=1)
def F(vec_e, vec_d, mask_e, mask_d):
  # pylint: disable=invalid-name
  L1 = mask_e.shape[1]
  L2 = mask_d.shape[1]
  # pylint: enable=invalid-name
  # [-(L1+L2), -L2) but with padding 0-ed out - (B, L1).
  mask_e_key = jnp.arange(-(L1 + L2), -L2) * mask_e
  # [-L2, 0) but with padding 0-ed out - (B, L2).
  mask_d_key = jnp.arange(-L2, 0) * mask_d

  # Shape (B, L1+L2, H)
  enc_dec_concat = jnp.concatenate([vec_e, vec_d], axis=1)
  # Shape (B, L1+L2)
  mask_concat = jnp.concatenate([mask_e_key, mask_d_key], axis=1)
  # Make `mask_concat` the same shape as `enc_dec_concat`.
  mask_concat = (
      mask_concat[..., jnp.newaxis] +
      jnp.zeros_like(enc_dec_concat, dtype=jnp.int32))
  # Sort on `mask_concat` so padding with key=0 goes to the right end, axis=1.
  _, enc_dec_pad = math.sort_key_val(mask_concat, enc_dec_concat, 1)

  return enc_dec_pad
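# Hedged standalone sketch of the sort-by-mask trick in F above, simplified to
# a single one-channel example. Negative keys mark real tokens (encoder keys
# are more negative than decoder keys) and zeros mark padding, so a stable key
# sort packs real encoder tokens first, then real decoder tokens, then all
# padding. jax.lax.sort_key_val stands in for math.sort_key_val here.
import jax
import jax.numpy as jnp

L1, L2 = 3, 2
vec_e = jnp.array([[10., 11., 0.]])        # (B=1, L1), last position is padding
vec_d = jnp.array([[20., 21.]])            # (B=1, L2)
mask_e = jnp.array([[1, 1, 0]])
mask_d = jnp.array([[1, 1]])

mask_e_key = jnp.arange(-(L1 + L2), -L2) * mask_e   # [-5, -4, 0]
mask_d_key = jnp.arange(-L2, 0) * mask_d            # [-2, -1]
keys = jnp.concatenate([mask_e_key, mask_d_key], axis=1)
vals = jnp.concatenate([vec_e, vec_d], axis=1)
_, packed = jax.lax.sort_key_val(keys, vals, dimension=1)
print(packed)                               # [[10. 11. 20. 21.  0.]]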
def deinterleave(inputs):
  reprs = inputs
  (batch_size, length) = reprs.shape[:2]
  shape_suffix = reprs.shape[2:]
  remainder_length = length % (x_size + y_size)
  if remainder_length > 0:
    remainder = reprs[:, None, -remainder_length:]
    reprs = reprs[:, :-remainder_length]
  reprs = jnp.reshape(reprs, (batch_size, -1, x_size + y_size) + shape_suffix)
  x_reprs = reprs[:, :, :x_size]
  y_reprs = reprs[:, :, x_size:]
  if remainder_length > 0:
    x_reprs = jnp.concatenate((x_reprs, remainder), axis=1)
  return (x_reprs, y_reprs)
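# Hedged toy walk-through of the deinterleave logic above in plain numpy
# (deinterleave itself closes over x_size and y_size, so the steps are redone
# here with made-up sizes x_size=2, y_size=1): a flat stream of 8 steps splits
# back into three x-groups and two y-groups, with the last x-group coming from
# the remainder.
import numpy as onp

x_size, y_size = 2, 1
flat = onp.array([[0., 1., 100., 2., 3., 101., 4., 5.]])      # (B=1, 8)

(batch_size, length) = flat.shape[:2]
remainder_length = length % (x_size + y_size)                  # == 2
remainder = flat[:, None, -remainder_length:]
body = onp.reshape(flat[:, :-remainder_length],
                   (batch_size, -1, x_size + y_size))
x_reprs = onp.concatenate((body[:, :, :x_size], remainder), axis=1)
y_reprs = body[:, :, x_size:]
print(x_reprs[0])   # [[0. 1.] [2. 3.] [4. 5.]]
print(y_reprs[0])   # [[100.] [101.]]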