Exemplo n.º 1
0
 def test_blocksparse_ff_predict_equals_eval(self):
     """Check that 'predict' mode matches 'eval' mode for BlockSparseFF.

     Both layers are built from the same keyword arguments and are run on
     the same all-ones input, with the weights and state returned by the
     eval layer's `init` and with the same PRNG key.
     """
     model_dim = 1024
     with fastmath.use_backend(fastmath.Backend.JAX):
         inputs = np.ones((1, 1, model_dim)).astype(np.float32)
         signature = shapes.signature(inputs)
         layer_kwargs = {
             'd_ff': model_dim * 8,
             'num_experts': 64,
             'temperature': 0.7,
         }

         eval_layer = sparsity.BlockSparseFF(mode='eval', **layer_kwargs)
         shared_weights, shared_state = eval_layer.init(signature)

         def forward(layer):
             # Run `layer` on the shared weights/state with a fixed key and
             # keep only the first element of the returned pair.
             output, _ = layer.pure_fn(inputs,
                                       shared_weights,
                                       shared_state,
                                       rng=jax.random.PRNGKey(0))
             return output

         eval_out = forward(eval_layer)

         predict_layer = sparsity.BlockSparseFF(mode='predict',
                                                **layer_kwargs)
         # The predict layer is initialized, but what `init` returns is
         # discarded: the forward pass below reuses the eval layer's
         # weights and state.
         _, _ = predict_layer.init(signature)
         pred_out = forward(predict_layer)

         self.assertEqual(eval_out.shape, inputs.shape)
         # The two modes are expected to agree on the single input position.
         np.testing.assert_array_almost_equal(eval_out[0, 0, :],
                                              pred_out[0, 0, :])
Exemplo n.º 2
0
 def test_blocksparse_ff_train(self):
   """Check that BlockSparseFF in 'train' mode preserves the input shape.

   The layer is built with d_ff = 8 * 1024, 64 experts and temperature 0.7,
   initialized from the signature of an all-ones float32 input of shape
   (3, 7, 1024), and then applied to that input.
   """
   model_dim = 1024
   with fastmath.use_backend(fastmath.Backend.JAX):
     train_layer = sparsity.BlockSparseFF(
         d_ff=model_dim * 8,
         num_experts=64,
         temperature=0.7,
         mode='train',
     )
     inputs = np.ones((3, 7, model_dim)).astype(np.float32)
     # What `init` returns is not needed here; calling the layer directly
     # below does not take weights or state as arguments.
     _, _ = train_layer.init(shapes.signature(inputs))
     outputs = train_layer(inputs)
     self.assertEqual(outputs.shape, inputs.shape)