Example #1
0
def test_przs_generate_random_share(get_clients) -> None:
    """Check Session.przs_generate_random_share against a fixed expected share.

    Both PRZS generators are built from fixed seeds (42 and 43), so the share
    produced for shape (2, 1) is compared with a hard-coded tensor.

    Args:
        get_clients: pytest fixture; its value is not used in this test
            (presumably requested for its setup side effects — confirm).
    """
    session = Session()
    SessionManager.setup_mpc(session)
    session.przs_generators = [get_new_generator(seed) for seed in (42, 43)]

    result = session.przs_generate_random_share(shape=(2, 1))
    assert isinstance(result, ShareTensor)

    # Precomputed output for the seed pair (42, 43); regenerate if the RNG changes.
    expected = torch.tensor([[-1540733531777602634], [2813554787685566880]])
    assert (result.tensor == expected).all()
Example #2
0
def test_prrs_share_tensor() -> None:
    """Test prrs_generate_random_share method from Session for ShareTensor.

    The session's generators are seeded from two random 32-bit seeds. The
    generated share is expected to equal the element drawn from a freshly
    built generator for the first seed, using the same shape and the
    session's tensor type. The seeds are reported on failure so a failing
    run can be reproduced.
    """
    session = Session()  # default protocol: FSS
    SessionManager.setup_mpc(session)
    seed1 = secrets.randbits(32)
    seed2 = secrets.randbits(32)
    gen1 = get_new_generator(seed1)
    gen2 = get_new_generator(seed2)
    session.przs_generators = [gen1, gen2]
    shape = (2, 1)
    share = session.prrs_generate_random_share(shape=shape)
    assert isinstance(share, ShareTensor)

    # Rebuild the first generator from its seed so it is expected to replay
    # the same stream the session consumed above.
    expected = generate_random_element(
        generator=get_new_generator(seed1),
        shape=shape,
        tensor_type=session.tensor_type,
    )
    assert (share.tensor == expected).all(), (
        f"PRRS share mismatch for seeds seed1={seed1}, seed2={seed2}"
    )
Example #3
0
def test_przs_rs_tensor() -> None:
    """Check przs_generate_random_share on a Falcon session (ReplicatedSharedTensor).

    A malicious-security Falcon session receives two generators built from
    random 32-bit seeds. The first share of the result is expected to equal
    element(seed1) - element(seed2), where each element is drawn from a fresh
    generator for that seed with the same shape and tensor type.
    """
    protocol = Falcon(security_type="malicious")
    session = Session(protocol=protocol)
    SessionManager.setup_mpc(session)

    seeds = [secrets.randbits(32), secrets.randbits(32)]
    session.przs_generators = [get_new_generator(seed) for seed in seeds]
    shape = (2, 1)

    share = session.przs_generate_random_share(shape=shape)
    assert isinstance(share, ReplicatedSharedTensor)

    # Fresh generators from the same seeds are expected to replay the streams
    # the session consumed above.
    replayed = [get_new_generator(seed) for seed in seeds]
    first, second = [
        generate_random_element(
            generator=gen, shape=shape, tensor_type=session.tensor_type
        )
        for gen in replayed
    ]
    assert (share.shares[0] == first - second).all()