def _optax_gradient_transformation(self, hps):
    """Builds the distributed Shampoo optax transformation from the hps container."""
    return optax_distributed_shampoo.distributed_shampoo(
        learning_rate=hps.learning_rate,
        block_size=hps.block_size,
        beta1=hps.beta1,
        beta2=hps.beta2,
        diagonal_epsilon=hps.diagonal_epsilon,
        matrix_epsilon=hps.matrix_epsilon,
        weight_decay=hps.weight_decay,
        start_preconditioning_step=hps.start_preconditioning_step,
        preconditioning_compute_steps=hps.preconditioning_compute_steps,
        statistics_compute_steps=hps.statistics_compute_steps,
        best_effort_shape_interpretation=hps.best_effort_shape_interpretation,
        graft_type=hps.graft_type,
        nesterov=hps.nesterov,
        exponent_override=hps.exponent_override,
        batch_axis_name=hps.batch_axis_name,
        statistics_partition_spec=hps.statistics_partition_spec,
        preconditioner_partition_spec=hps.preconditioner_partition_spec,
        num_devices_for_pjit=hps.num_devices_for_pjit,
        shard_optimizer_states=hps.shard_optimizer_states,
        best_effort_memory_usage_reduction=hps.best_effort_memory_usage_reduction,
        inverse_failure_threshold=hps.inverse_failure_threshold,
        moving_average_for_momentum=hps.moving_average_for_momentum,
        skip_preconditioning_dim_size_gt=hps.skip_preconditioning_dim_size_gt,
        clip_by_scaled_gradient_norm=hps.clip_by_scaled_gradient_norm,
        precision=hps.precision)
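
The `hps` object above is assumed to be a hyperparameter container exposing one attribute per keyword argument. A minimal sketch of such a container, here built with ml_collections.ConfigDict and purely illustrative values (not the library's documented defaults), might look like this:

# Hypothetical hps container for _optax_gradient_transformation above.
# Field names mirror the keyword arguments; the values are illustrative only.
import ml_collections

hps = ml_collections.ConfigDict(dict(
    learning_rate=0.05,
    block_size=128,
    beta1=0.9,
    beta2=0.999,
    diagonal_epsilon=1e-10,
    matrix_epsilon=1e-6,
    weight_decay=0.0,
    start_preconditioning_step=5,
    preconditioning_compute_steps=1,
    statistics_compute_steps=1,
    best_effort_shape_interpretation=True,
    graft_type=1,  # grafting variant; the call above passes this straight through
    nesterov=True,
    exponent_override=0,
    batch_axis_name='batch',
    statistics_partition_spec=None,
    preconditioner_partition_spec=None,
    num_devices_for_pjit=None,
    shard_optimizer_states=False,
    best_effort_memory_usage_reduction=False,
    inverse_failure_threshold=0.1,
    moving_average_for_momentum=False,
    skip_preconditioning_dim_size_gt=4096,
    clip_by_scaled_gradient_norm=None,
    precision=None,  # e.g. jax.lax.Precision.HIGHEST
))
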
    def test_distributed_shampoo(
        self,
        best_effort_memory_usage_reduction,
        symmetric_block_size,
        block_statistics,
    ):
        params = self.init_params

        optim = distributed_shampoo.distributed_shampoo(
            0.1,
            32,
            batch_axis_name='batch',
            preconditioning_compute_steps=2,
            best_effort_memory_usage_reduction=best_effort_memory_usage_reduction,
        )
        init_fn = self.variant(optim.init)
        transform_fn = self.variant(optim.update)

        def _update(unused_batch):
            # `state` and `params` are captured by closure; `state` is
            # initialized below, before the pmap'd call.
            return transform_fn(self.per_step_updates, state, params)

        state = init_fn(params)
        chex.assert_tree_all_finite(state)
        pmap_fn = jax.pmap(_update, axis_name='batch')

        updates, state = pmap_fn(jnp.array([1.0]))
        chex.assert_tree_all_finite((params, updates, state))

    def test_distributed_shampoo_no_pmap(self):
        params = self.init_params

        optim = distributed_shampoo.distributed_shampoo(
            0.1,
            32,
            batch_axis_name=None,
            preconditioning_compute_steps=2,
        )
        init_fn = self.variant(optim.init)
        transform_fn = self.variant(optim.update)
        state = init_fn(params)
        chex.assert_tree_all_finite(state)
        updates, state = transform_fn(self.per_step_updates, state, params)
        chex.assert_tree_all_finite((params, updates, state))
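
Outside the test harness, the same transformation follows the usual optax init/update cycle. A minimal standalone sketch, assuming the same distributed_shampoo module the tests use and a made-up parameter tree:

# Sketch of the optax-style API exercised above; params and grads are
# placeholders, and the constructor arguments follow the no-pmap test.
import jax
import jax.numpy as jnp
import optax

optim = distributed_shampoo.distributed_shampoo(
    0.1, 32, batch_axis_name=None, preconditioning_compute_steps=2)

params = {'w': jnp.ones((8, 4)), 'b': jnp.zeros((4,))}
grads = jax.tree_util.tree_map(jnp.ones_like, params)  # stand-in gradients

state = optim.init(params)
updates, state = optim.update(grads, state, params)
params = optax.apply_updates(params, updates)
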
Example #4

    def test_distributed_shampoo(self):
        params = self.init_params

        optim = distributed_shampoo.distributed_shampoo(
            0.1, 32, batch_axis_name='batch', preconditioning_compute_steps=2)
        init_fn = self.variant(optim.init)
        transform_fn = self.variant(optim.update)

        def _update(unused_batch):
            return transform_fn(self.per_step_updates, state, params)

        state = init_fn(params)
        chex.assert_tree_all_finite(state)
        pmap_fn = jax.pmap(_update, axis_name='batch')

        updates, state = pmap_fn(jnp.array([1.0]))
        chex.assert_tree_all_finite((params, updates, state))