def _test_fast_inference(self, model_cls, x, input_signature, common_kwargs,
                         *test_kwargs):
  """Checks that 'predict' mode matches the reference 'eval' mode output."""
  ref_model = model_cls(use_reference_code=True, mode='eval', **common_kwargs)
  weights, state = ref_model.init(input_signature)
  ref_out, _ = ref_model.pure_fn(x, weights, state, rng=jax.random.PRNGKey(0))

  def get_slice(pytree, i):
    # Slice out position i (or a length-1 signature) from every leaf.
    def get_slice_for_val(x):
      if isinstance(x, shapes.ShapeDtype):
        return shapes.ShapeDtype(shape=x.shape[:1] + (1,) + x.shape[2:],
                                 dtype=x.dtype)
      else:
        return x[:, i:i + 1]
    return jax.tree_map(get_slice_for_val, pytree)

  seqlen = x[0].shape[1] if isinstance(x, (tuple, list)) else x.shape[1]

  for kwargs in test_kwargs:
    test_model = model_cls(mode='predict', **common_kwargs, **kwargs)
    cur_state = test_model.init(get_slice(input_signature, 0))[1]
    out = []
    # Feed the sequence one position at a time, threading decoder state.
    for i in range(seqlen):
      cur_out, cur_state = test_model.pure_fn(
          get_slice(x, i), weights, cur_state, jax.random.PRNGKey(0))
      out.append(cur_out)
    out = jnp.concatenate(out, axis=1)
    self.assertAllClose(out, ref_out, rtol=1e-3, atol=1e-3)
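# A minimal, self-contained sketch of the equivalence the helper above
# checks: running a causal model over a whole sequence at once ('eval')
# should match feeding it one position at a time while threading state
# ('predict'). The toy running-mean "model" below is an illustrative
# assumption, not a Trax model.
import jax.numpy as jnp

def _eval_mode(x):
  # Full-sequence forward pass: mean of positions 0..t at each step t.
  csum = jnp.cumsum(x, axis=1)
  counts = jnp.arange(1, x.shape[1] + 1)[None, :]
  return csum / counts

def _predict_step(x_t, state):
  # One position at a time, threading a (total, count) decoding state.
  total, count = state
  total, count = total + x_t, count + 1
  return total / count, (total, count)

x = jnp.arange(6.0).reshape(1, 6)
state = (jnp.zeros((1,)), 0)
outs = []
for i in range(x.shape[1]):
  out, state = _predict_step(x[:, i], state)
  outs.append(out[:, None])
assert jnp.allclose(jnp.concatenate(outs, axis=1), _eval_mode(x))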
def forward(self, x):
  rng = self.rng
  batch_size, length = x.shape[0], x.shape[1]
  max_pos = min(self._bases)**self._n_digits
  rng1, rng2, rng3 = fastmath.random.split(rng, 3)
  assert length < max_pos, 'length (%d) >= max_pos (%d)' % (length, max_pos)
  positions = jnp.arange(0, length)[None, :]

  if self._mode == 'train':
    # In 1% of training cases still start from 0 to be exactly as in eval.
    start_from_nonzero = jax.random.randint(
        rng1, (batch_size,), 0, self._start_from_zero_one_in)
    start_from_nonzero = jnp.minimum(1, start_from_nonzero)
    random_start = jax.random.randint(rng2, (batch_size,), 0,
                                      max_pos - length)
    random_start *= start_from_nonzero
    positions += random_start[:, None]

  res = []
  for bn, base in enumerate(self._bases):
    pos_embeddings = []
    cur_positions = positions
    # Decompose each position into its digits in the given base; each digit
    # scales its own learned embedding vector.
    for i in range(self._n_digits):
      cur_indices = jnp.mod(cur_positions, base)
      cur_positions = cur_positions // base
      s = self.weights[bn][i]
      pos_embeddings.append(cur_indices.astype(jnp.float32)[:, :, None] * s)
    embeddings = jnp.concatenate(pos_embeddings, axis=-1)
    if self._mode == 'train':
      # Occasionally drop a whole base's embedding for regularization.
      base_dropout = jax.random.randint(
          rng3, (batch_size,), 0, self._base_dropout_one_in)
      base_dropout = jnp.minimum(1, base_dropout).astype(jnp.float32)
      embeddings *= base_dropout[:, None, None]
    res.append(embeddings)

  res = sum(res) + jnp.zeros_like(x)
  return x + res
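# Sketch of the digit decomposition performed by the inner loop above: each
# position is written with n_digits digits in a given base, and every digit
# scales its own learned vector. The concrete numbers are illustrative.
import jax.numpy as jnp

positions = jnp.array([0, 7, 42])
base, n_digits = 5, 3
digits, cur = [], positions
for _ in range(n_digits):
  digits.append(jnp.mod(cur, base))  # least-significant digit first
  cur = cur // base
# 42 in base 5 is 132, so its digits (LSB first) are [2, 3, 1].
print(jnp.stack(digits, axis=-1))    # [[0 0 0] [2 1 0] [2 3 1]]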
def _fast_matrix_shift(x, funnel_factor, is_upsampling=False):
  """Fast matrix shift.

  Implements the shift needed for relative positional attention
  calculations. Based on funnel_factor and on whether we are upsampling or
  downsampling, it computes the shift amount and the interval at which the
  correct attention values are picked.

  Args:
    x: matrix.
    funnel_factor: factor to be used for the shift.
    is_upsampling: whether we are upsampling.

  Returns:
    Shifted matrix x.
  """
  # shift: the i-th row is shifted by i * shift elements to the left.
  # k: after the shift, we pick every k-th element.
  if is_upsampling:
    k = funnel_factor
    shift = 1
  else:
    k = 1
    shift = funnel_factor

  bsz, n_head = x.shape[0], x.shape[1]
  qlen, klen = x.shape[2], (x.shape[3] + 1) // 2

  zero_pad = jnp.zeros((bsz, n_head, qlen, shift))
  x = jnp.concatenate([zero_pad, x], axis=3)
  x = x.reshape(bsz, n_head, 2 * klen - 1 + shift, qlen)
  x = x[:, :, shift:, :]
  x = x.reshape(bsz, n_head, qlen, klen * 2 - 1)
  x = x[:, :, :, shift - 1:shift - 1 + klen:k]
  return x
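# A tiny worked example of the pad-reshape-slice trick above, with
# bsz = n_head = 1, qlen = klen = 3 and funnel_factor = 1 (so shift = k = 1).
# Column j of the input holds relative position j - (klen - 1); after the
# shift, entry (i, j) holds relative position j - i, as relative attention
# requires. The shapes are toy assumptions.
import jax.numpy as jnp

qlen, klen = 3, 3
rel = jnp.arange(-(klen - 1), klen).astype(jnp.float32)    # [-2 -1 0 1 2]
x = jnp.tile(rel[None, None, None, :], (1, 1, qlen, 1))
print(_fast_matrix_shift(x, funnel_factor=1)[0, 0])
# [[ 0.  1.  2.]
#  [-1.  0.  1.]
#  [-2. -1.  0.]]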
def forward(self, x): """Executes this layer as part of a forward pass through the model. Args: x: Tensor of same shape and dtype as the input signature used to initialize this layer. Returns: Tensor of same shape and dtype as the input, except the final dimension is the layer's `filters` value, and the second to last dimension is shrinked if 'VALID' padding is used with kernel_size bigger than one. """ if self._use_bias: if not isinstance(self.weights, (tuple, list)): raise ValueError(f'Weights should be a (w, b) tuple or list; ' f'instead got: {self.weights}') w, b = self.weights else: w = self.weights linear_results_before_shifting = jnp.einsum('...lp,lkpd->...lkd', x, w) # TODO(jaszczur): this could be run after padding for better efficiency if self._kernel_size == 1: # With kernel size 1 we don't have to split or shift anything. linear_result = jnp.squeeze(linear_results_before_shifting, axis=-2) else: # We computed a result for every "pixel", but each direction from the # receptive field (there are 'self._kernel_size' such directions) must be # shifted by a different amount. The easiest way to do it is to split # the tensor to 'self._kernel_size' smaller tensors, shift each one # appropriately, and then sum them together. split_shifting_linear_results = jnp.split( linear_results_before_shifting, self._kernel_size, axis=-2) for i in range(self._kernel_size): # Each tensor has to be shifted a different amount. if self._padding == 'WRAP': # We can shift by padding and cutting. With 'wrap' padding we # essentially have a torus. padding = [(0, 0) for i in split_shifting_linear_results[i].shape] padding[-3] = ((self._kernel_size - 1) - i, i) split_shifting_linear_results[i] = jnp.pad( split_shifting_linear_results[i], padding, mode='wrap') split_shifting_linear_results[ i] = split_shifting_linear_results[i][ ..., (self._kernel_size - 1) // 2:-(self._kernel_size - 1) // 2, :, :] elif self._padding == 'SAME': # We can shift by padding and cutting. padding = [(0, 0) for i in split_shifting_linear_results[i].shape] padding[-3] = ((self._kernel_size - 1) - i, i) split_shifting_linear_results[i] = jnp.pad( split_shifting_linear_results[i], padding) split_shifting_linear_results[ i] = split_shifting_linear_results[i][ ..., (self._kernel_size - 1) // 2:-(self._kernel_size - 1) // 2, :, :] # TODO(jaszczur): improve efficiency by not padding things to cut elif self._padding == 'VALID': # We don't need to shift - just cut the leftmost and rightmost values. cut_left = (self._kernel_size - 1) - i cut_right = split_shifting_linear_results[i].shape[-3] - i split_shifting_linear_results[ i] = split_shifting_linear_results[i][ ..., cut_left:cut_right, :, :] else: raise ValueError(f'Invalid padding {self._padding}') # After shifting. shifted_linear_results = jnp.concatenate( split_shifting_linear_results, axis=-2) linear_result = jnp.sum(shifted_linear_results, axis=-2) if self._use_bias: return linear_result + b else: return linear_result
def forward(self, xs):
  return jnp.concatenate(xs, self._axis)
def forward(self, xs): """Executes this layer as part of a forward pass through the model.""" return jnp.concatenate(xs, self._axis)
def test_dot_product_self_attention(target):
    n_heads = 2
    d_head = 3
    successful_cases = 0
    failed_cases = []

    q = jnp.array([[1, 0, 0], [0, 1, 0]])
    k = jnp.array([[1, 2, 3], [4, 5, 6]])
    v = jnp.array([[0, 1, 0], [1, 0, 1]])
    m = jnp.array([[0, 0], [-1e9, 0]])

    test_cases = [
        {
            "name": "test dummy tensors",
            "input": {"q": q[None, :], "k": k[None, :], "v": v[None, :]},
            "expected": jnp.array(
                [[[0.0, 1.0, 0.0], [0.8496746, 0.15032543, 0.8496746]]]
            ),
            "error_message": [
                "Expected shape does not match.",
                "Expected output does not match.",
            ],
        },
        {
            "name": "test dummy tensors",
            "input": {
                "q": jnp.array([jnp.concatenate([q, q], axis=-1)]),
                "k": jnp.array([jnp.concatenate([k, k], axis=-1)]),
                "v": jnp.array([jnp.concatenate([v, v], axis=-1)]),
            },
            "expected": jnp.array(
                [
                    [
                        [0.0, 1.0, 0.0, 0.0, 1.0, 0.0],
                        [
                            0.9205239,
                            0.07947586,
                            0.9205239,
                            0.9205239,
                            0.07947586,
                            0.9205239,
                        ],
                    ]
                ]
            ),
            "error_message": [
                "Expected shape does not match.",
                "Expected output does not match.",
            ],
        },
    ]

    for test_case in test_cases:
        name = test_case.get("name")
        input_dict = test_case.get("input")
        expected = test_case.get("expected")
        output = target(**input_dict)

        try:
            assert output.shape == expected.shape
            successful_cases += 1
        except:
            print(test_case.get("error_message")[0])
            failed_cases.append(
                {
                    "name": test_case["name"],
                    "expected": test_case["expected"].shape,
                    "got": output.shape,
                }
            )

        try:
            assert jnp.isclose(output, expected).all()
            successful_cases += 1
        except:
            print(test_case.get("error_message")[1])
            failed_cases.append(
                {
                    "name": test_case["name"],
                    "expected": test_case["expected"],
                    "got": output,
                }
            )

    if len(failed_cases) == 0:
        print("\033[92m All tests passed")
    else:
        print("\033[92m", successful_cases, " Tests passed")
        print("\033[91m", len(failed_cases), " Tests failed")
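# Where the reference numbers above come from: scaled dot-product attention
# with a causal mask (the tested target is expected to build the mask
# internally). Recomputing the first test case by hand, under that
# assumption:
import jax
import jax.numpy as jnp

q = jnp.array([[1., 0., 0.], [0., 1., 0.]])
k = jnp.array([[1., 2., 3.], [4., 5., 6.]])
v = jnp.array([[0., 1., 0.], [1., 0., 1.]])
scores = q @ k.T / jnp.sqrt(3.0)                    # scale by sqrt(d_head)
causal = jnp.tril(jnp.ones((2, 2), dtype=bool))     # mask future positions
weights = jax.nn.softmax(jnp.where(causal, scores, -1e9), axis=-1)
print(weights @ v)  # [[0. 1. 0.], [0.8496746 0.15032543 0.8496746]]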
def test_compute_attention_heads_closure(target):
    n_heads = 2
    d_head = 3
    successful_cases = 0
    failed_cases = []

    q = jnp.array([[1, 0, 0], [0, 1, 0]])

    test_cases = [
        {
            "name": "test dummy tensors",
            "input": {
                "x": jnp.array(
                    [
                        jnp.concatenate([q, q], axis=-1),
                        jnp.concatenate([q, q], axis=-1),
                    ]
                )
            },
            "expected": jnp.array(
                [
                    [[1, 0, 0], [0, 1, 0]],
                    [[1, 0, 0], [0, 1, 0]],
                    [[1, 0, 0], [0, 1, 0]],
                    [[1, 0, 0], [0, 1, 0]],
                ]
            ),
            "error_message": [
                "Expected shape does not match.",
                "Expected output does not match.",
            ],
        },
        {
            "name": "test dummy tensors",
            "input": {
                "x": jnp.array(
                    [
                        jnp.concatenate([q, q], axis=-1),
                        jnp.concatenate([q, q], axis=-1),
                        jnp.concatenate([q, q], axis=-1),
                    ]
                )
            },
            "expected": jnp.array(
                [
                    [[1, 0, 0], [0, 1, 0]],
                    [[1, 0, 0], [0, 1, 0]],
                    [[1, 0, 0], [0, 1, 0]],
                    [[1, 0, 0], [0, 1, 0]],
                    [[1, 0, 0], [0, 1, 0]],
                    [[1, 0, 0], [0, 1, 0]],
                ]
            ),
            "error_message": [
                "Expected shape does not match.",
                "Expected output does not match.",
            ],
        },
    ]

    for test_case in test_cases:
        name = test_case.get("name")
        input_dict = test_case.get("input")
        expected = test_case.get("expected")
        output = target(n_heads=n_heads, d_head=d_head)(**input_dict)

        try:
            assert output.shape == expected.shape
            successful_cases += 1
        except:
            print(test_case.get("error_message")[0])
            failed_cases.append(
                {
                    "name": test_case["name"],
                    "expected": test_case["expected"].shape,
                    "got": output.shape,
                }
            )

        try:
            assert jnp.isclose(output, expected).all()
            successful_cases += 1
        except:
            print(test_case.get("error_message")[1])
            failed_cases.append(
                {
                    "name": test_case["name"],
                    "expected": test_case["expected"],
                    "got": output,
                }
            )

    if len(failed_cases) == 0:
        print("\033[92m All tests passed")
    else:
        print("\033[92m", successful_cases, " Tests passed")
        print("\033[91m", len(failed_cases), " Tests failed")
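# A minimal sketch of the reshaping the tested closure is expected to do:
# (batch, seqlen, n_heads * d_head) -> (batch * n_heads, seqlen, d_head).
# This stand-in is an assumption matching the expected values above.
import jax.numpy as jnp

def compute_attention_heads(x, n_heads=2, d_head=3):
    batch, seqlen = x.shape[0], x.shape[1]
    x = x.reshape(batch, seqlen, n_heads, d_head)
    x = x.transpose(0, 2, 1, 3)                      # (batch, heads, seq, d)
    return x.reshape(batch * n_heads, seqlen, d_head)

x = jnp.arange(2 * 2 * 6).reshape(2, 2, 6)
print(compute_attention_heads(x).shape)              # (4, 2, 3)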
def _unshard_fn(x):
  # Gather the shards from all devices in the 'batch' axis, then fold the
  # leading device dimension back into a data axis chosen heuristically.
  y = jax.lax.all_gather(x, 'batch', axis_index_groups=groups)
  split_y = jnp.split(y, n_shards, axis=0)
  split_y = [jnp.squeeze(sy, axis=0) for sy in split_y]
  axis = _axis_to_shard_heuristic(split_y[0].shape)
  return jnp.concatenate(split_y, axis=axis)
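# Shape-level sketch of the unshard step above: all_gather returns a leading
# shard dimension, which is split off and folded back into a data axis. The
# gather is simulated here with stacking (single host), and the target axis
# is fixed to 0 instead of using the heuristic; both are assumptions.
import jax.numpy as jnp

n_shards = 4
shards = [jnp.full((2, 3), i) for i in range(n_shards)]  # per-device values
y = jnp.stack(shards, axis=0)        # what all_gather would return: (4, 2, 3)
split_y = jnp.split(y, n_shards, axis=0)
split_y = [jnp.squeeze(sy, axis=0) for sy in split_y]
print(jnp.concatenate(split_y, axis=0).shape)            # (8, 3)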
def rotate_half(x):
  # Split the feature axis into two halves and swap them, negating the
  # second half.
  x1, x2 = x[..., :x.shape[-1] // 2], x[..., x.shape[-1] // 2:]
  return jnp.concatenate((-x2, x1), axis=x1.ndim - 1)
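# How rotate_half is typically used in rotary position embeddings (RoPE):
# the rotation x * cos + rotate_half(x) * sin. The angle layout below (two
# half-blocks sharing frequencies) matches this rotate_half convention and
# is an illustrative assumption.
import jax
import jax.numpy as jnp

seq, d = 4, 8
inv_freq = 1.0 / (10000 ** (jnp.arange(0, d, 2) / d))    # (d/2,)
angles = jnp.arange(seq)[:, None] * inv_freq[None, :]    # (seq, d/2)
angles = jnp.concatenate([angles, angles], axis=-1)      # (seq, d)
x = jax.random.normal(jax.random.PRNGKey(0), (seq, d))
x_rot = x * jnp.cos(angles) + rotate_half(x) * jnp.sin(angles)
print(x_rot.shape)  # (4, 8)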
def Sinusoidal_Embeddings(positions):
  # d_feature is taken from the enclosing scope.
  inv_freq = 1 / (10000**(jnp.arange(0.0, d_feature, 2.0) / d_feature))
  sinusoid_freq = jnp.einsum('i,j->ij', positions, inv_freq)
  pos_emb = jnp.concatenate(
      [jnp.sin(sinusoid_freq), jnp.cos(sinusoid_freq)], axis=1)
  return pos_emb
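# Usage sketch: d_feature is a free variable above, fixed here to 4 for
# illustration. Each position gets d_feature/2 sine columns followed by
# d_feature/2 cosine columns at geometrically spaced frequencies.
import jax.numpy as jnp

d_feature = 4
pos_emb = Sinusoidal_Embeddings(jnp.arange(3.0))
print(pos_emb.shape)  # (3, 4): columns are [sin f0, sin f1, cos f0, cos f1]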
def pad_borders(v):
  # Split the length axis into (pre, mid, post), pad the border pieces up to
  # the chunk length, and stitch them back together. split_along_l,
  # chunk_offset, last_chunk_len and pad_to_chunk_len come from the
  # enclosing scope.
  total_len = v.shape[2]
  pre, mid, post = split_along_l(v, chunk_offset, total_len - last_chunk_len,
                                 total_len)
  pre, post = map(pad_to_chunk_len, [pre, post])
  return jnp.concatenate([pre, mid, post], axis=2)
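# Shape sketch of the border padding: the length axis (axis 2) splits into a
# pre-chunk, an aligned middle, and a trailing chunk; pre and post are padded
# up to chunk_len so the result divides into whole chunks. The helpers and
# constants below are illustrative stand-ins for the closure variables above.
import jax.numpy as jnp

chunk_len, chunk_offset, last_chunk_len = 4, 2, 3

def split_along_l(v, a, b, end):
  return v[:, :, :a], v[:, :, a:b], v[:, :, b:end]

def pad_to_chunk_len(v):
  pad = chunk_len - v.shape[2]
  return jnp.pad(v, ((0, 0), (0, 0), (0, pad), (0, 0)))

v = jnp.ones((1, 1, 13, 2))               # total_len = 13
print(pad_borders(v).shape)               # (1, 1, 16, 2): 16 = 4 * chunk_len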