def test_blocksparse_ff_predict_equals_eval(self):
    """Check that BlockSparseFF gives the same output in 'predict' and 'eval'.

    Both layers are built from the same hyperparameters, and the predict
    layer is run with the weights and state obtained from initializing the
    eval layer, so the two outputs are expected to agree.
    """
    model_depth = 1024
    ff_kwargs = {
        'd_ff': model_depth * 8,
        'num_experts': 64,
        'temperature': 0.7,
    }

    with fastmath.use_backend(fastmath.Backend.JAX):
        inputs = np.ones((1, 1, model_depth), dtype=np.float32)
        signature = shapes.signature(inputs)

        # Reference run: the eval-mode layer with its own fresh weights.
        eval_layer = sparsity.BlockSparseFF(mode='eval', **ff_kwargs)
        weights, state = eval_layer.init(signature)
        eval_result, _ = eval_layer.pure_fn(
            inputs, weights, state, rng=jax.random.PRNGKey(0))

        # The predict layer is initialized, but its own weights and state
        # are discarded; the forward pass reuses the eval layer's values.
        predict_layer = sparsity.BlockSparseFF(mode='predict', **ff_kwargs)
        _, _ = predict_layer.init(signature)
        predict_result, _ = predict_layer.pure_fn(
            inputs, weights, state, rng=jax.random.PRNGKey(0))

        self.assertEqual(eval_result.shape, inputs.shape)
        # The single position's feature vector must match across modes.
        np.testing.assert_array_almost_equal(
            eval_result[0, 0, :], predict_result[0, 0, :])
def test_blocksparse_ff_train(self):
    """Check that BlockSparseFF in 'train' mode preserves the input shape."""
    model_depth = 1024
    input_shape = (3, 7, model_depth)

    with fastmath.use_backend(fastmath.Backend.JAX):
        ff_layer = sparsity.BlockSparseFF(
            d_ff=model_depth * 8,
            num_experts=64,
            temperature=0.7,
            mode='train',
        )
        inputs = np.ones(input_shape, dtype=np.float32)

        # Initialize from the input signature; the returned weights and
        # state are not needed here because the layer is called directly.
        _, _ = ff_layer.init(shapes.signature(inputs))
        outputs = ff_layer(inputs)

        self.assertEqual(outputs.shape, inputs.shape)