def test_sparse_ff_predict_equals_eval(self):
    """SparseFF in 'predict' mode must reproduce the 'eval' mode output.

    Both layers are built with identical hyperparameters. The weights and
    state obtained by initializing the eval layer are fed to both forward
    passes, together with the same RNG key, so any difference between the
    outputs comes from the mode-specific code path.
    """
    d_model = 1024
    n_elements_in_block = 64
    d_ff = d_model * 8
    # Presumably (batch, length, d_model): a single position of width d_model.
    x_shape = (1, 1, d_model)
    temperature = 0.7
    with fastmath.use_backend(fastmath.Backend.JAX):
        x = np.ones(x_shape).astype(np.float32)
        input_signature = shapes.signature(x)
        common_kwargs = dict(
            d_ff=d_ff,
            n_elements_in_block=n_elements_in_block,
            temperature=temperature,
        )

        eval_model = sparsity.SparseFF(mode='eval', **common_kwargs)
        weights, state = eval_model.init(input_signature)
        eval_out, _ = eval_model.pure_fn(
            x, weights, state, rng=jax.random.PRNGKey(0))

        pred_model = sparsity.SparseFF(mode='predict', **common_kwargs)
        # Initialize the predict layer but discard what init returns: the eval
        # layer's weights are reused so both paths share identical parameters.
        pred_model.init(input_signature)
        # NOTE(review): the eval layer's `state` is passed to the predict layer;
        # this assumes SparseFF's state has the same structure in both modes.
        pred_out, _ = pred_model.pure_fn(
            x, weights, state, rng=jax.random.PRNGKey(0))

        self.assertEqual(eval_out.shape, x.shape)
        self.assertEqual(pred_out.shape, x.shape)
        # eval_out and pred_out should be identical. Compare the full arrays
        # instead of a single [0, 0, :] slice so that a shape mismatch or extra
        # elements in the predict output cannot go unnoticed.
        np.testing.assert_array_almost_equal(eval_out, pred_out)
def test_sparse_ff_with_chunking(self, mode):
    """A SparseFF layer with ff_chunk_size set keeps the input's shape.

    Runs one forward pass in the given `mode` on an all-ones input and checks
    only that the output shape equals the input shape.
    """
    feature_dim = 8
    with fastmath.use_backend(fastmath.Backend.JAX):
        inputs = np.ones((2, 8, feature_dim)).astype(np.float32)
        signature = shapes.signature(inputs)
        # ff_chunk_size=4 is smaller than batch * length (2 * 8 = 16),
        # presumably so that more than one chunk is processed.
        layer = sparsity.SparseFF(
            d_ff=16,
            n_elements_in_block=2,
            temperature=0.7,
            ff_chunk_size=4,
            mode=mode,
        )
        params, layer_state = layer.init(signature)
        outputs, _ = layer.pure_fn(
            inputs, params, layer_state, rng=jax.random.PRNGKey(0))
        self.assertEqual(outputs.shape, inputs.shape)