def testLossAndGradientsAreFinite(self):
  # Test that the loss and its approximation both give finite losses and
  # derivatives everywhere that they should for a wide range of values.
  num_samples = 100000
  rng = random.PRNGKey(0)

  # Normally distributed inputs.
  rng, key = random.split(rng)
  x = random.normal(key, shape=[num_samples])

  # Uniformly distributed values in (-16, 3), quantized to the nearest 0.1
  # to ensure that we hit the special cases at 0, 2.
  rng, key = random.split(rng)
  alpha = jnp.round(
      random.uniform(key, shape=[num_samples], minval=-16, maxval=3) *
      10) / 10.

  # Random log-normally distributed values in approx (1e-5, 100000):
  rng, key = random.split(rng)
  scale = jnp.exp(random.normal(key, shape=[num_samples]) * 4.) + 1e-5

  fn = self.variant(general.lossfun)
  loss = fn(x, alpha, scale)
  d_x, d_alpha, d_scale = (
      jax.grad(lambda x, a, s: jnp.sum(fn(x, a, s)), [0, 1, 2])(
          x, alpha, scale))

  for v in [loss, d_x, d_alpha, d_scale]:
    chex.assert_tree_all_finite(v)
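# These methods rely on chex's variant machinery: `self.variant` wraps the
# function under test so each test body runs both jitted and unjitted. A
# minimal sketch of that harness follows, assuming the standard chex API;
# the class and test names here are illustrative, not from the suite above.
import chex
import jax.numpy as jnp


class LossfunTest(chex.TestCase):

  @chex.variants(with_jit=True, without_jit=True)
  def test_square_is_finite(self):
    # `self.variant` returns the function jitted or as-is, depending on
    # which variant is currently being exercised.
    square = self.variant(lambda x: x * x)
    chex.assert_tree_all_finite(square(jnp.array(3.0)))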
def test_distributed_shampoo(
    self,
    best_effort_memory_usage_reduction,
    symmetric_block_size,
    block_statistics,
):
  params = self.init_params
  optim = distributed_shampoo.distributed_shampoo(
      0.1,
      32,
      batch_axis_name='batch',
      preconditioning_compute_steps=2,
      best_effort_memory_usage_reduction=best_effort_memory_usage_reduction,
  )
  init_fn = self.variant(optim.init)
  transform_fn = self.variant(optim.update)

  def _update(unused_batch):
    return transform_fn(self.per_step_updates, state, params)

  state = init_fn(params)
  chex.assert_tree_all_finite(state)
  pmap_fn = jax.pmap(_update, axis_name='batch')
  updates, state = pmap_fn(jnp.array([1.0]))
  chex.assert_tree_all_finite((params, updates, state))
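# Why the tests pair `batch_axis_name='batch'` with
# `jax.pmap(..., axis_name='batch')`: the optimizer issues collectives over
# that named axis, so the name passed to the optimizer must match the axis
# name pmap introduces (and `_update` closes over `state` and `params`,
# which is why they are bound before `pmap_fn` is called). A minimal sketch
# of the same pattern with a hand-rolled collective; the function and
# variable names here are illustrative:
import jax
import jax.numpy as jnp


def _mean_over_replicas(x):
  # lax.pmean averages across all devices participating in the 'batch' axis.
  return jax.lax.pmean(x, axis_name='batch')


per_device = jnp.arange(jax.local_device_count(), dtype=jnp.float32)
averaged = jax.pmap(_mean_over_replicas, axis_name='batch')(per_device)
# Every replica now holds the same cross-device mean.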
def test_scalers(
    constructor: Callable[[], GradientTransformation[Any, Any]]) -> None:
  params = init_params
  scaler = constructor()
  init_fn = variant(scaler.init)
  transform_fn = variant(scaler.update)
  state = init_fn(params)
  chex.assert_tree_all_finite(state)
  updates, state = transform_fn(per_step_updates, state, params)
  chex.assert_tree_all_finite((params, updates, state))
  tree_map(lambda *args: chex.assert_equal_shape(args), params, updates)
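# The bare names `init_params`, `per_step_updates`, `variant`, and `tree_map`
# in the pytest-style test above are module-level fixtures (`Callable` and
# `Any` come from `typing`, and `GradientTransformation` from the optimizer
# library under test). A minimal sketch of what those fixtures could look
# like follows; the concrete values and the identity `variant` are
# assumptions for illustration, not taken from the test module:
import jax.numpy as jnp
from jax.tree_util import tree_map

init_params = (jnp.array([1.0, 2.0]), jnp.array([3.0, 4.0]))
per_step_updates = (jnp.array([500.0, 5.0]), jnp.array([300.0, 3.0]))


def variant(fn):
  # Stand-in for a chex variant wrapper: return the function unchanged.
  return fn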
def test_scale_by_factored_rms(self):
  params = self.init_params
  scaler = factorized.scale_by_factored_rms()
  init_fn = self.variant(scaler.init)
  transform_fn = self.variant(scaler.update)
  state = init_fn(params)
  chex.assert_tree_all_finite(state)
  updates, state = transform_fn(self.per_step_updates, state, params)
  chex.assert_tree_all_finite((params, updates, state))
  chex.assert_tree_all_equal_shapes(params, updates)
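# `scale_by_factored_rms` is the Adafactor-style transform: it stores row and
# column factors of the squared-gradient statistics instead of a full
# second-moment tensor, so updates retain the parameter shapes checked above.
# A minimal usage sketch via the public optax API; the learning rate is an
# arbitrary illustrative value:
import optax

optimizer = optax.chain(
    optax.scale_by_factored_rms(),
    optax.scale(-1e-3),  # negate so the step descends rather than ascends
)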
def testGradientsAreFiniteWithAllInputs(self, alpha):
  x_half = jnp.concatenate(
      [jnp.exp(jnp.linspace(-80, 80, 1001)),
       jnp.array([jnp.inf])])
  x = jnp.concatenate([-x_half[::-1], jnp.array([0.]), x_half])
  scale = jnp.full_like(x, 1.)

  fn = self.variant(lambda x, s: general.lossfun(x, alpha, s))
  loss = fn(x, scale)
  d_x, d_scale = jax.vmap(jax.grad(fn, [0, 1]))(x, scale)

  for v in [loss, d_x, d_scale]:
    chex.assert_tree_all_finite(v)
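# For intuition about what is being differentiated: away from the special
# cases alpha in {0, 2}, the general robust loss (Barron, 2019) that
# `general.lossfun` implements has the closed form below. This reference
# version is a readability sketch only; it omits the numerically stable
# branching the library uses at the special cases and at extreme inputs.
import jax.numpy as jnp


def reference_lossfun(x, alpha, scale):
  z = jnp.square(x / scale)
  b = jnp.abs(alpha - 2.0)
  return (b / alpha) * (jnp.power(z / b + 1.0, alpha / 2.0) - 1.0)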
def test_scalers(self, scaler_constr):
  params = self.init_params
  scaler = scaler_constr()
  init_fn = self.variant(scaler.init)
  transform_fn = self.variant(scaler.update)
  state = init_fn(params)
  chex.assert_tree_all_finite(state)
  updates, state = transform_fn(self.per_step_updates, state, params)
  chex.assert_tree_all_finite((params, updates, state))
  jax.tree_multimap(
      lambda *args: chex.assert_equal_shape(args), params, updates)
def test_distributed_shampoo_no_pmap(self):
  params = self.init_params
  optim = distributed_shampoo.distributed_shampoo(
      0.1,
      32,
      batch_axis_name=None,
      preconditioning_compute_steps=2,
  )
  init_fn = self.variant(optim.init)
  transform_fn = self.variant(optim.update)
  state = init_fn(params)
  chex.assert_tree_all_finite(state)
  updates, state = transform_fn(self.per_step_updates, state, params)
  chex.assert_tree_all_finite((params, updates, state))
def test_sm3_basic(self):
  params = self.init_params
  optim = sm3.sm3(0.1, 0.9, 0.999)
  init_fn = self.variant(optim.init)
  transform_fn = self.variant(optim.update)

  def _update(unused_batch):
    return transform_fn(self.per_step_updates, state, params)

  state = init_fn(params)
  chex.assert_tree_all_finite(state)
  pmap_fn = jax.pmap(_update, axis_name='batch')
  updates, state = pmap_fn(jnp.array([1.0]))
  chex.assert_tree_all_finite((params, updates, state))
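# SM3 (Anil et al., 2019) keeps per-dimension accumulators rather than a
# full second-moment tensor, hence the finiteness-only checks above. A
# hedged usage sketch with the same hyperparameters the test passes; the
# positional meaning (learning rate, momentum, second-moment decay) is an
# assumption, and `sm3` is whatever module the tests import:
import jax.numpy as jnp

params = {'w': jnp.ones((3, 4))}
grads = {'w': jnp.full((3, 4), 0.1)}

optimizer = sm3.sm3(0.1, 0.9, 0.999)
opt_state = optimizer.init(params)
updates, opt_state = optimizer.update(grads, opt_state, params)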
def test_distributed_shampoo(self):
  params = self.init_params
  optim = distributed_shampoo.distributed_shampoo(
      0.1, 32, batch_axis_name='batch', preconditioning_compute_steps=2)
  init_fn = self.variant(optim.init)
  transform_fn = self.variant(optim.update)

  def _update(unused_batch):
    return transform_fn(self.per_step_updates, state, params)

  state = init_fn(params)
  chex.assert_tree_all_finite(state)
  pmap_fn = jax.pmap(_update, axis_name='batch')
  updates, state = pmap_fn(jnp.array([1.0]))
  chex.assert_tree_all_finite((params, updates, state))