def init_generators(self, seed_current: int, seed_next: int) -> None:
    """Create the generators used for Pseudo Random Zero Shares.

    The two generators are stored on the instance as ``przs_generators``
    in the order ``[current, next]``.

    Args:
        seed_current (int): the seed for our party
        seed_next (int): the seed for the next party
    """
    # Seeds are consumed in order: our own first, then the next party's.
    self.przs_generators = [
        get_new_generator(seed) for seed in (seed_current, seed_next)
    ]
def test_przs_generate_random_share(get_clients) -> None:
    """Check przs_generate_random_share from Session against pinned values.

    The session generators are seeded with the fixed seeds 42 and 43, and the
    produced share is compared with hard-coded expected tensor entries.
    """
    # NOTE(review): ``get_clients`` is requested but never referenced in the
    # body -- presumably a pytest fixture; confirm it is still required.
    session = Session()
    SessionManager.setup_mpc(session)

    # Fixed seeds keep the generated share reproducible.
    session.przs_generators = [get_new_generator(42), get_new_generator(43)]

    result = session.przs_generate_random_share(shape=(2, 1))
    assert isinstance(result, ShareTensor)

    expected = torch.tensor(([-1540733531777602634], [2813554787685566880]))
    matches = result.tensor == expected
    assert matches.all()
def test_prrs_share_tensor() -> None:
    """Check prrs_generate_random_share from Session for ShareTensor.

    The produced share must equal the element drawn from a fresh generator
    built with the same seed as the session's first generator.
    """
    session = Session()  # default protocol: FSS
    SessionManager.setup_mpc(session)

    first_seed = secrets.randbits(32)
    second_seed = secrets.randbits(32)
    session.przs_generators = [
        get_new_generator(first_seed),
        get_new_generator(second_seed),
    ]

    dims = (2, 1)
    result = session.prrs_generate_random_share(shape=dims)
    assert isinstance(result, ShareTensor)

    # Replay the first generator from its seed to get the reference value.
    expected = generate_random_element(
        generator=get_new_generator(first_seed),
        shape=dims,
        tensor_type=session.tensor_type,
    )
    matches = result.tensor == expected
    assert matches.all()
def test_przs_rs_tensor() -> None:
    """Check przs_generate_random_share from Session for ReplicatedSharedTensor.

    With a Falcon protocol session, the first entry of the result's ``shares``
    must equal the difference between the elements drawn from fresh generators
    built with the same two seeds as the session generators.
    """
    session = Session(protocol=Falcon(security_type="malicious"))
    SessionManager.setup_mpc(session)

    first_seed = secrets.randbits(32)
    second_seed = secrets.randbits(32)
    session.przs_generators = [
        get_new_generator(first_seed),
        get_new_generator(second_seed),
    ]

    dims = (2, 1)
    result = session.przs_generate_random_share(shape=dims)
    assert isinstance(result, ReplicatedSharedTensor)

    # Rebuild both generators from their seeds before drawing, then replay
    # one element from each to form the reference value.
    replay_generators = [
        get_new_generator(seed) for seed in (first_seed, second_seed)
    ]
    first_draw, second_draw = [
        generate_random_element(
            generator=gen, shape=dims, tensor_type=session.tensor_type
        )
        for gen in replay_generators
    ]

    expected = first_draw - second_draw
    matches = result.shares[0] == expected
    assert matches.all()