예제 #1
0
def test_prrs_share_tensor() -> None:
    """Test prrs_generate_random_share method from Session for ShareTensor.

    With the default protocol the generated share is expected to be a
    ``ShareTensor`` whose values equal those produced by a fresh generator
    re-seeded with the first PRZS seed, for the same shape and tensor type.
    """
    session = Session()  # default protocol: FSS
    SessionManager.setup_mpc(session)
    seed1 = secrets.randbits(32)
    seed2 = secrets.randbits(32)
    gen1 = get_new_generator(seed1)
    gen2 = get_new_generator(seed2)
    session.przs_generators = [gen1, gen2]
    shape = (2, 1)
    share = session.prrs_generate_random_share(shape=shape)
    assert isinstance(share, ShareTensor)

    # Replay the first generator from the same seed to rebuild the expected values.
    new_gen1 = get_new_generator(seed1)
    target_tensor = generate_random_element(
        generator=new_gen1,
        shape=shape,
        tensor_type=session.tensor_type,
    )
    # Compare shapes explicitly: elementwise ``==`` broadcasts and could pass
    # even if the generated share had a different (broadcastable) shape.
    assert share.tensor.shape == target_tensor.shape
    assert (share.tensor == target_tensor).all()
예제 #2
0
def test_prrs_rs_tensor() -> None:
    """Test prrs_generate_random_share method from Session for ReplicatedSharedTensor.

    With the Falcon protocol (malicious security) the generated share is
    expected to be a ``ReplicatedSharedTensor`` holding two shares, each equal
    to the values produced by a fresh generator re-seeded with the
    corresponding PRZS seed, for the same shape and tensor type.
    """
    falcon = Falcon(security_type="malicious")
    session = Session(protocol=falcon)
    SessionManager.setup_mpc(session)
    seed1 = secrets.randbits(32)
    seed2 = secrets.randbits(32)
    gen1 = get_new_generator(seed1)
    gen2 = get_new_generator(seed2)
    session.przs_generators = [gen1, gen2]
    shape = (2, 1)
    share = session.prrs_generate_random_share(shape=shape)
    assert isinstance(share, ReplicatedSharedTensor)

    # Replay both generators from their seeds to rebuild the expected shares.
    new_gen1 = get_new_generator(seed1)
    new_gen2 = get_new_generator(seed2)
    share1 = generate_random_element(
        generator=new_gen1,
        shape=shape,
        tensor_type=session.tensor_type,
    )
    share2 = generate_random_element(
        generator=new_gen2,
        shape=shape,
        tensor_type=session.tensor_type,
    )
    target_tensor = [share1, share2]

    # Check share-by-share rather than comparing concatenations: this pins the
    # number of shares and each share's shape, which a broadcasting ``==`` on
    # concatenated tensors would not.
    assert len(share.shares) == len(target_tensor)
    for actual, expected in zip(share.shares, target_tensor):
        assert actual.shape == expected.shape
        assert (actual == expected).all()