  def test_matrix_inverse_root_padding(self, sz):
    """Test that padding does not materially affect the result."""
    # Note: the sz == 1 case will not pass the tests here because the method
    # is exact for scalars (but padding triggers the coupled iteration).
    condition_number = 1e3
    ms = self._gen_symmetrix_matrix(sz, condition_number).astype(np.float32)

    # Shift the matrix norm down by a large factor, so that improper padding
    # handling results in an error by increasing the condition number.
    ms = jnp.array(ms) * 1e-3

    rt, metrics = distributed_shampoo.matrix_inverse_pth_root(
        ms, 4, ridge_epsilon=1e-3)
    err = metrics.inverse_pth_root_errors
    pad_ms = distributed_shampoo.pad_square_matrix(ms, sz * 2)
    pad_rt, metrics = distributed_shampoo.matrix_inverse_pth_root(
        pad_ms, 4, ridge_epsilon=1e-3, padding_start=sz)
    pad_err = metrics.inverse_pth_root_errors
    pad_rt_principal = pad_rt[:sz, :sz]
    np.testing.assert_allclose(
        rt,
        pad_rt_principal,
        # The fact that this is so large keeps vladf up at night,
        # but without the padding_start argument it's even worse (>1).
        rtol=1e-2,
        err_msg=np.array2string(rt - pad_rt_principal))
    self.assertLessEqual(pad_err, 4 * err)
    self.assertEqual(np.abs(pad_rt[sz:]).sum(), 0)
    self.assertEqual(np.abs(pad_rt[:, sz:]).sum(), 0)
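  # A minimal sketch of how a symmetric test matrix with a target condition
  # number could be generated. The real _gen_symmetrix_matrix helper is not
  # shown in this excerpt; this _example_symmetric_matrix is a hypothetical
  # stand-in, assuming a random-orthogonal-basis construction.
  def _example_symmetric_matrix(self, dim, condition_number):
    rng = np.random.RandomState(seed=0)
    # Draw a random orthogonal basis Q via QR decomposition.
    q, _ = np.linalg.qr(rng.randn(dim, dim))
    # Spread eigenvalues log-uniformly between 1 and condition_number.
    eigs = np.exp(np.linspace(0., np.log(condition_number), dim))
    # Q diag(eigs) Q^T is symmetric PSD with the requested condition number.
    return (q * eigs).dot(q.T)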
  def test_all_padding(self):
    """Test a matrix that is entirely padding."""
    empty = jnp.zeros([0, 0])
    padded = distributed_shampoo.pad_square_matrix(empty, 10)
    rt, metrics = distributed_shampoo.matrix_inverse_pth_root(
        padded, 4, ridge_epsilon=1e-3, padding_start=0)
    err = metrics.inverse_pth_root_errors
    self.assertEqual(np.abs(rt).sum(), 0.0)
    self.assertEqual(np.abs(err).sum(), 0.0)
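  # An illustrative companion check (a sketch, assuming pad_square_matrix
  # produces the block layout [[M, 0], [0, I]]): padding an empty 0x0 matrix
  # to size 10 would then yield the 10x10 identity, all of which is masked
  # out by padding_start=0 above.
  def test_all_padding_is_identity(self):
    padded = distributed_shampoo.pad_square_matrix(jnp.zeros([0, 0]), 10)
    np.testing.assert_allclose(padded, np.eye(10))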
  def test_pad_square_matrix_error(self, shape, max_size):
    with self.assertRaises(ValueError):
      distributed_shampoo.pad_square_matrix(
          mat=jnp.ones(shape=shape), max_size=max_size)
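  # Hypothetical concrete failure cases, consistent with the parameterized
  # test above (the actual (shape, max_size) parameterizations are not shown
  # in this excerpt): a non-square input, or a target size smaller than the
  # matrix, are the natural inputs for which square padding is ill-defined.
  def test_pad_square_matrix_error_examples(self):
    with self.assertRaises(ValueError):
      # Non-square matrices cannot be padded to a square of size max_size.
      distributed_shampoo.pad_square_matrix(
          mat=jnp.ones(shape=(3, 5)), max_size=10)
    with self.assertRaises(ValueError):
      # A max_size smaller than the matrix would require truncation.
      distributed_shampoo.pad_square_matrix(
          mat=jnp.ones(shape=(4, 4)), max_size=3)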
  def test_pad_square_matrix(self, max_size, result):
    self.assertAllClose(
        distributed_shampoo.pad_square_matrix(
            mat=jnp.ones(shape=(3, 3), dtype=jnp.float32), max_size=max_size),
        jnp.asarray(result, dtype=jnp.float32))
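  # A worked example of the presumed padding layout (an assumption matching
  # the block structure [[M, 0], [0, I]]; the parameterized `result` values
  # are not shown in this excerpt): padding a 3x3 all-ones matrix to
  # max_size=5 would place the ones block top-left, zeros in the off-diagonal
  # blocks, and an identity block bottom-right.
  def test_pad_square_matrix_example(self):
    expected = np.asarray(
        [[1., 1., 1., 0., 0.],
         [1., 1., 1., 0., 0.],
         [1., 1., 1., 0., 0.],
         [0., 0., 0., 1., 0.],
         [0., 0., 0., 0., 1.]],
        dtype=np.float32)
    np.testing.assert_allclose(
        distributed_shampoo.pad_square_matrix(
            mat=jnp.ones(shape=(3, 3), dtype=jnp.float32), max_size=5),
        expected)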