def test_qr_mpc(hook, workers):
    """
    Check QR decomposition of a matrix secret-shared as an AdditiveSharedTensor.

    A random 3x3 matrix is encoded in fixed precision (6 fractional digits),
    shared between bob and alice with james as crypto provider, and decomposed
    with ``qr`` in "complete" mode. The retrieved, decoded factors are checked
    for: Q orthogonal, R upper triangular, and Q @ R reproducing the input.
    A loose 1e-2 tolerance is used because the data went through
    fixed-precision encoding.
    """
    alice, bob = workers["alice"], workers["bob"]
    crypto_provider = workers["james"]

    # Fixed-precision truncation might not always work, so pin the seed.
    torch.manual_seed(0)

    rows = cols = 3
    tensor = torch.randn([rows, cols])
    shared = tensor.fix_precision(precision_fractional=6).share(
        bob, alice, crypto_provider=crypto_provider
    )

    # norm_factor is sqrt(3) here; presumably a scaling hint for the secure
    # norm computation inside qr -- TODO confirm against qr's documentation.
    q_shared, r_shared = qr(shared, norm_factor=3 ** (1 / 2), mode="complete")
    q_mat = q_shared.get().float_precision()
    r_mat = r_shared.get().float_precision()

    tol = 1e-2

    # Q must be orthogonal: Q @ Q^T should be (close to) the identity.
    gram = q_mat @ q_mat.t()
    assert ((torch.eye(rows) - gram).abs() < tol).all()

    # R must be upper triangular: every entry below the diagonal is ~0.
    # The last column has nothing below the diagonal in a square matrix.
    for c in range(cols - 1):
        below_diag = r_mat[c + 1 :, c]
        assert (below_diag.abs() < tol).all()

    # The factors must reproduce the original matrix.
    assert ((q_mat @ r_mat - tensor).abs() < tol).all()
def test_qr(hook, workers):
    """
    Check QR decomposition of a plain float matrix held remotely by bob.

    A random 10x5 matrix is sent to bob and decomposed with ``qr`` in
    "complete" mode. The retrieved factors must satisfy, within 1e-5:
    Q orthogonal, R upper triangular, and Q @ R == t. The test then checks
    the shapes returned for each mode ("reduced", "complete", "r").
    """
    # The input is a plain float tensor, with no fixed-precision encoding and
    # hence no truncation. The seed only makes the random input reproducible.
    torch.manual_seed(42)
    bob = workers["bob"]
    n_cols = 5
    n_rows = 10
    t = torch.randn([n_rows, n_cols])

    Q, R = qr(t.send(bob), mode="complete")
    Q = Q.get()
    R = R.get()

    # Check that Q is orthogonal: Q @ Q^T should be (close to) the identity.
    gram = Q @ Q.t()
    assert ((torch.eye(n_rows) - gram).abs() < 1e-5).all()

    # Check that R is upper triangular: entries below the diagonal are ~0.
    for col in range(n_cols):
        assert ((R[col + 1 :, col]).abs() < 1e-5).all()

    # Check that the factors reproduce the original matrix.
    assert ((Q @ R - t).abs() < 1e-5).all()

    # Test the output shapes of each mode. Each call sends a fresh copy of t
    # to bob; the shapes are read from the returned (remote) results.
    Q, R = qr(t.send(bob), mode="reduced")
    assert Q.shape == (n_rows, n_cols)
    assert R.shape == (n_cols, n_cols)

    Q, R = qr(t.send(bob), mode="complete")
    assert Q.shape == (n_rows, n_rows)
    assert R.shape == (n_rows, n_cols)

    # In "r" mode only R is returned.
    R = qr(t.send(bob), mode="r")
    assert R.shape == (n_cols, n_cols)