Esempio n. 1
0
 def test_sparse_ff_predict_equals_eval(self):
     """Checks SparseFF produces the same output in 'predict' and 'eval' mode.

     Both layers are built from the same hyperparameters. Each is applied to the
     same all-ones input, using the weights and state returned by the eval
     layer's init and the same rng key. Their outputs are then compared
     element-wise.
     """
     d_model = 1024
     n_elements_in_block = 64
     d_ff = d_model * 8
     x_shape = (1, 1, d_model)
     temperature = 0.7
     with fastmath.use_backend(fastmath.Backend.JAX):
         x = np.ones(x_shape).astype(np.float32)
         input_signature = shapes.signature(x)
         common_kwargs = dict(
             d_ff=d_ff,
             n_elements_in_block=n_elements_in_block,
             temperature=temperature,
         )
         eval_model = sparsity.SparseFF(mode='eval', **common_kwargs)
         weights, state = eval_model.init(input_signature)
         eval_out, _ = eval_model.pure_fn(x,
                                          weights,
                                          state,
                                          rng=jax.random.PRNGKey(0))
         pred_model = sparsity.SparseFF(mode='predict', **common_kwargs)
         # The returned weights/state are discarded: init is presumably called
         # only to initialize pred_model itself. The eval model's weights and
         # state are what get passed to pure_fn below.
         _, _ = pred_model.init(input_signature)
         pred_out, _ = pred_model.pure_fn(x,
                                          weights,
                                          state,
                                          rng=jax.random.PRNGKey(0))
         self.assertEqual(eval_out.shape, x.shape)
         self.assertEqual(pred_out.shape, x.shape)
         # Compare the complete outputs rather than a slice, so that a shape
         # mismatch or differing elements anywhere in pred_out fail the test.
         np.testing.assert_array_almost_equal(eval_out, pred_out)
Esempio n. 2
0
 def test_sparse_ff_with_chunking(self, mode):
   """Builds a chunked SparseFF (ff_chunk_size=4) and checks its output shape.

   Args:
     mode: layer mode passed straight to SparseFF; presumably supplied by a
       parameterization decorator that is not visible here.
   """
   model_dim = 8
   input_shape = (2, 8, model_dim)
   layer_config = {
       'd_ff': 16,
       'n_elements_in_block': 2,
       'temperature': 0.7,
       'ff_chunk_size': 4,
       'mode': mode,
   }
   with fastmath.use_backend(fastmath.Backend.JAX):
     inputs = np.ones(input_shape).astype(np.float32)
     sparse_ff = sparsity.SparseFF(**layer_config)
     params, layer_state = sparse_ff.init(shapes.signature(inputs))
     rng_key = jax.random.PRNGKey(0)
     outputs, _unused_state = sparse_ff.pure_fn(
         inputs, params, layer_state, rng=rng_key)
     # The chunked layer must preserve the input shape.
     self.assertEqual(outputs.shape, inputs.shape)