def _optax_gradient_transformation(self, hps):
  """Builds the optax distributed Shampoo transformation from hyperparameters."""
  return optax_distributed_shampoo.distributed_shampoo(
      learning_rate=hps.learning_rate,
      block_size=hps.block_size,
      beta1=hps.beta1,
      beta2=hps.beta2,
      diagonal_epsilon=hps.diagonal_epsilon,
      matrix_epsilon=hps.matrix_epsilon,
      weight_decay=hps.weight_decay,
      start_preconditioning_step=hps.start_preconditioning_step,
      preconditioning_compute_steps=hps.preconditioning_compute_steps,
      statistics_compute_steps=hps.statistics_compute_steps,
      best_effort_shape_interpretation=hps.best_effort_shape_interpretation,
      graft_type=hps.graft_type,
      nesterov=hps.nesterov,
      exponent_override=hps.exponent_override,
      batch_axis_name=hps.batch_axis_name,
      statistics_partition_spec=hps.statistics_partition_spec,
      preconditioner_partition_spec=hps.preconditioner_partition_spec,
      num_devices_for_pjit=hps.num_devices_for_pjit,
      shard_optimizer_states=hps.shard_optimizer_states,
      best_effort_memory_usage_reduction=hps.best_effort_memory_usage_reduction,
      inverse_failure_threshold=hps.inverse_failure_threshold,
      moving_average_for_momentum=hps.moving_average_for_momentum,
      skip_preconditioning_dim_size_gt=hps.skip_preconditioning_dim_size_gt,
      clip_by_scaled_gradient_norm=hps.clip_by_scaled_gradient_norm,
      precision=hps.precision)
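# Usage sketch (ours, hypothetical): the wrapper above only reads attributes
# off `hps`, so any attribute container works. `SimpleNamespace` and every
# value below are illustrative assumptions, not defaults taken from the
# distributed Shampoo library.
from types import SimpleNamespace

import jax

example_hps = SimpleNamespace(
    learning_rate=0.1,
    block_size=128,
    beta1=0.9,
    beta2=0.999,
    diagonal_epsilon=1e-10,
    matrix_epsilon=1e-6,
    weight_decay=0.0,
    start_preconditioning_step=5,
    preconditioning_compute_steps=1,
    statistics_compute_steps=1,
    best_effort_shape_interpretation=True,
    graft_type=optax_distributed_shampoo.GraftingType.SGD,  # grafting enum from the same module
    nesterov=False,
    exponent_override=0,
    batch_axis_name=None,
    statistics_partition_spec=None,
    preconditioner_partition_spec=None,
    num_devices_for_pjit=None,
    shard_optimizer_states=False,
    best_effort_memory_usage_reduction=False,
    inverse_failure_threshold=0.1,
    moving_average_for_momentum=False,
    skip_preconditioning_dim_size_gt=4096,
    clip_by_scaled_gradient_norm=None,
    precision=jax.lax.Precision.HIGHEST,
)
# optimizer._optax_gradient_transformation(example_hps) then yields an
# optax.GradientTransformation with `.init` and `.update`, where `optimizer`
# is whatever object defines the method above.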
def test_distributed_shampoo(
    self,
    best_effort_memory_usage_reduction,
    symmetric_block_size,
    block_statistics,
):
  params = self.init_params

  optim = distributed_shampoo.distributed_shampoo(
      0.1,
      32,
      batch_axis_name='batch',
      preconditioning_compute_steps=2,
      best_effort_memory_usage_reduction=best_effort_memory_usage_reduction,
  )
  init_fn = self.variant(optim.init)
  transform_fn = self.variant(optim.update)

  # `state` is late-bound: `_update` closes over the name, which is assigned
  # below, before the pmapped call actually runs.
  def _update(unused_batch):
    return transform_fn(self.per_step_updates, state, params)

  state = init_fn(params)
  chex.assert_tree_all_finite(state)
  pmap_fn = jax.pmap(_update, axis_name='batch')

  updates, state = pmap_fn(jnp.array([1.0]))
  chex.assert_tree_all_finite((params, updates, state))
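# Self-contained illustration (ours, not from the test file) of the closure
# pattern above: `_update` refers to `state`, which is assigned after the
# function definition but before the pmapped call, so the late-bound lookup
# succeeds when `_update` is traced.
import jax
import jax.numpy as jnp

def _closure_demo():
  def _update(unused_batch):
    return state + 1.0  # `state` resolves at trace time, not definition time

  state = jnp.zeros([])  # bound here, before the call below
  return jax.pmap(_update, axis_name='batch')(jnp.array([1.0]))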
def test_distributed_shampoo_no_pmap(self):
  params = self.init_params

  optim = distributed_shampoo.distributed_shampoo(
      0.1,
      32,
      batch_axis_name=None,
      preconditioning_compute_steps=2,
  )
  init_fn = self.variant(optim.init)
  transform_fn = self.variant(optim.update)

  state = init_fn(params)
  chex.assert_tree_all_finite(state)
  updates, state = transform_fn(self.per_step_updates, state, params)
  chex.assert_tree_all_finite((params, updates, state))
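# Note (our addition): with `batch_axis_name=None` the optimizer assumes no
# cross-device axis, so `init` and `update` run directly rather than under
# `jax.pmap`. In normal (non-test) use the resulting `updates` would then be
# applied with `optax.apply_updates(params, updates)`.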
def test_distributed_shampoo_pmap(self):
  # Non-parameterized variant: exercises the pmap path with default
  # memory-usage settings.
  params = self.init_params

  optim = distributed_shampoo.distributed_shampoo(
      0.1, 32, batch_axis_name='batch', preconditioning_compute_steps=2)
  init_fn = self.variant(optim.init)
  transform_fn = self.variant(optim.update)

  def _update(unused_batch):
    return transform_fn(self.per_step_updates, state, params)

  state = init_fn(params)
  chex.assert_tree_all_finite(state)
  pmap_fn = jax.pmap(_update, axis_name='batch')

  updates, state = pmap_fn(jnp.array([1.0]))
  chex.assert_tree_all_finite((params, updates, state))
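# The tests above assume fixtures and variant machinery from an enclosing
# chex test case: `self.variant` comes from `chex.TestCase` methods decorated
# with `chex.all_variants` / `chex.variants`, and the extra arguments to the
# parameterized test come from `parameterized.named_parameters`. A minimal
# fixture sketch (class name and values are our assumptions, not the repo's):
from absl.testing import parameterized
import chex
import jax.numpy as jnp

class DistributedShampooTestSketch(chex.TestCase, parameterized.TestCase):

  def setUp(self):
    super().setUp()
    # Two small parameter matrices with matching per-step gradient updates.
    self.init_params = (
        jnp.array([[1., 3.], [2., 4.]]),
        jnp.array([[3., 4.], [3., 4.]]),
    )
    self.per_step_updates = (
        jnp.array([[500., 5.], [500., 5.]]),
        jnp.array([[300., 3.], [300., 3.]]),
    )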