def test_qr_mpc(hook, workers):
    """
    Testing QR decomposition with an AdditiveSharedTensor

    A random square matrix is fixed-precision encoded, secret-shared between
    bob and alice (james is the crypto provider) and decomposed with
    ``qr(..., mode="complete")``. Q and R are retrieved with ``.get()`` and
    decoded with ``.float_precision()``. The test then checks, within a 1e-2
    tolerance, that Q is orthogonal, R is upper triangular and Q @ R
    reproduces the original matrix.
    """
    bob = workers["bob"]
    alice = workers["alice"]
    crypto_prov = workers["james"]

    # Truncation might not always work so we set the random seed
    torch.manual_seed(0)
    n_cols = 3
    n_rows = 3
    t = torch.randn([n_rows, n_cols])
    t_sh = t.fix_precision(precision_fractional=6).share(
        bob, alice, crypto_provider=crypto_prov)

    # NOTE(review): norm_factor is sqrt(3); presumably chosen relative to the
    # matrix size to keep fixed-precision intermediates in range — confirm
    # against qr() before changing n_rows / n_cols.
    Q, R = qr(t_sh, norm_factor=3**(1 / 2), mode="complete")
    Q = Q.get().float_precision()
    R = R.get().float_precision()

    # Check if Q is orthogonal
    gram = Q @ Q.t()
    assert ((torch.eye(n_rows) - gram).abs() < 1e-2).all()

    # Check if R is upper triangular: tril(..., diagonal=-1) keeps exactly the
    # entries strictly below the main diagonal, for any matrix shape.
    assert (torch.tril(R, diagonal=-1).abs() < 1e-2).all()

    # Check if QR == t
    assert ((Q @ R - t).abs() < 1e-2).all()


def test_qr(hook, workers):
    """
    Testing QR decomposition with remote matrix

    A random (n_rows x n_cols) float matrix is sent to bob and decomposed with
    ``qr``. For "complete" mode the test checks that Q is orthogonal, R is
    upper triangular and Q @ R reproduces the input, all within 1e-5. It then
    checks the output shapes of the "reduced", "complete" and "r" modes.
    "Reduced" mode is also checked for orthonormal columns and reconstruction.
    Every remote result is fetched with ``.get()``, so the shape assertions
    run on local tensors and no results are left on bob.
    """
    # Fix the seed so the random input, and hence the tolerance checks, are
    # reproducible across runs.
    torch.manual_seed(42)

    bob = workers["bob"]
    n_cols = 5
    n_rows = 10
    t = torch.randn([n_rows, n_cols])
    Q, R = qr(t.send(bob), mode="complete")
    Q = Q.get()
    R = R.get()

    # Check if Q is orthogonal
    gram = Q @ Q.t()
    assert ((torch.eye(n_rows) - gram).abs() < 1e-5).all()

    # Check if R is upper triangular: tril(..., diagonal=-1) keeps exactly the
    # entries strictly below the main diagonal, for any matrix shape.
    assert (torch.tril(R, diagonal=-1).abs() < 1e-5).all()

    # Check if QR == t
    assert ((Q @ R - t).abs() < 1e-5).all()

    # test modes
    Q, R = qr(t.send(bob), mode="reduced")
    Q, R = Q.get(), R.get()
    assert Q.shape == (n_rows, n_cols)
    assert R.shape == (n_cols, n_cols)
    # A reduced Q is not square, so check orthonormal columns (Q^T Q == I)
    # rather than Q Q^T == I, and check that it still reconstructs t.
    assert ((torch.eye(n_cols) - Q.t() @ Q).abs() < 1e-5).all()
    assert ((Q @ R - t).abs() < 1e-5).all()

    Q, R = qr(t.send(bob), mode="complete")
    Q, R = Q.get(), R.get()
    assert Q.shape == (n_rows, n_rows)
    assert R.shape == (n_rows, n_cols)

    R = qr(t.send(bob), mode="r").get()
    assert R.shape == (n_cols, n_cols)