def test_matrix_inverse_root_padding(self, sz):
        """Check that padding a matrix barely perturbs its inverse pth root."""

        # Note sz == 1 case will not pass tests here b/c the method
        # is exact for scalars (but padding triggers coupled iteration).

        cond = 1e3
        mat = self._gen_symmetrix_matrix(sz, cond).astype(np.float32)

        # Scale the matrix norm down by a large factor; if padding is
        # mishandled, the effective condition number rises and so does
        # the reported error.
        mat = jnp.array(mat) * 1e-3

        base_root, base_metrics = distributed_shampoo.matrix_inverse_pth_root(
            mat, 4, ridge_epsilon=1e-3)
        base_err = base_metrics.inverse_pth_root_errors

        padded = distributed_shampoo.pad_square_matrix(mat, sz * 2)
        padded_root, padded_metrics = distributed_shampoo.matrix_inverse_pth_root(
            padded, 4, ridge_epsilon=1e-3, padding_start=sz)
        padded_err = padded_metrics.inverse_pth_root_errors

        principal = padded_root[:sz, :sz]
        np.testing.assert_allclose(
            base_root,
            principal,
            # The fact that this is so large keeps vladf up at night,
            # but without padding_start argument it's even worse (>1).
            rtol=1e-2,
            err_msg=np.array2string(base_root - principal))
        self.assertLessEqual(padded_err, 4 * base_err)
        # The padded rows/columns of the result must be exactly zero.
        self.assertEqual(np.abs(padded_root[sz:]).sum(), 0)
        self.assertEqual(np.abs(padded_root[:, sz:]).sum(), 0)
 def test_all_padding(self):
     """An all-padding matrix should yield an all-zero root and zero error."""
     zero_mat = jnp.zeros([0, 0])
     full_pad = distributed_shampoo.pad_square_matrix(zero_mat, 10)
     root, stats = distributed_shampoo.matrix_inverse_pth_root(
         full_pad, 4, ridge_epsilon=1e-3, padding_start=0)
     root_err = stats.inverse_pth_root_errors
     self.assertEqual(np.abs(root).sum(), 0.0)
     self.assertEqual(np.abs(root_err).sum(), 0.0)
 def test_pad_square_matrix_error(self, shape, max_size):
     """Padding with an incompatible shape/max_size must raise ValueError."""
     mat = jnp.ones(shape=shape)
     with self.assertRaises(ValueError):
         distributed_shampoo.pad_square_matrix(mat=mat, max_size=max_size)
 def test_pad_square_matrix(self, max_size, result):
     """Padding a 3x3 ones matrix should produce the expected array."""
     ones_3x3 = jnp.ones(shape=(3, 3), dtype=jnp.float32)
     padded = distributed_shampoo.pad_square_matrix(mat=ones_3x3,
                                                    max_size=max_size)
     expected = jnp.asarray(result, dtype=jnp.float32)
     self.assertAllClose(padded, expected)