示例#1
0
def test_unpack_typerror_crypten_model():
    """Unpacking a packed CrypTen model without passing a target model must raise TypeError."""
    sample = th.rand(1, 28 * 28)
    reference_model = crypten.nn.from_pytorch(ExampleNet(), sample)
    packed_model = utils.pack_values(reference_model)

    # No `model=` argument is supplied, so unpacking has nowhere to load the parameters.
    with pytest.raises(TypeError):
        utils.unpack_values(packed_model)
示例#2
0
def test_pack_crypten_model():
    """Packing a CrypTen model, then unpacking into a zeroed copy, restores its outputs."""

    class ExampleNet(th.nn.Module):
        def __init__(self):
            super().__init__()
            self.fc = th.nn.Linear(28 * 28, 2)

        def forward(self, x):
            return self.fc(x)

    sample = th.rand(1, 28 * 28)
    reference_model = crypten.nn.from_pytorch(ExampleNet(), sample)
    reference_out = reference_model(sample)

    packed = utils.pack_values(reference_model)

    # Wipe every parameter in place so a matching output can only come from `packed`.
    with th.no_grad():
        for param in reference_model.parameters():
            assert isinstance(param, th.Tensor)
            param.set_(th.zeros_like(param))

    restored_model = utils.unpack_values(packed, model=reference_model)

    restored_out = restored_model(sample)
    assert th.all(reference_out == restored_out)
示例#3
0
def test_run_party():
    """A single-party run returns the plain text of the value it encrypted."""
    expected = th.tensor(5)

    def party():  # pragma: no cover
        encrypted = crypten.cryptensor(expected)
        return encrypted.get_plain_text()

    packed_result = run_party(None, party, 0, 1, "127.0.0.1", 15463, (), {})
    unpacked = utils.unpack_values(packed_result)
    assert unpacked == expected
示例#4
0
def test_pack_tensors(tensors):
    """Round-trip `tensors` through pack_values/unpack_values and compare element-wise."""
    unpacked = utils.unpack_values(utils.pack_values(tensors))

    if not isinstance(unpacked, tuple):
        # A single tensor comes back on its own.
        assert th.all(unpacked == tensors)
        return

    # Several tensors come back as a tuple; compare them pairwise in order.
    assert len(tensors) == len(unpacked)
    for got, want in zip(unpacked, tensors):
        assert th.all(got == want)
示例#5
0
def _send_party_info(worker, rank, msg, return_values, model=None):
    """Ask ``worker`` to start its CrypTen party and record the reply under ``rank``.

    The reply's ``.object`` payload is passed through ``utils.unpack_values``
    (together with ``model``), and the result is stored in ``return_values``.

    Args:
        worker (BaseWorker): worker to send the message to; also passed as the
            message location.
        rank (int): rank of the CrypTen party run by ``worker``; used as the
            key in ``return_values``.
        msg (CryptenInitMessage): message with the information the worker needs
            to start its party (e.g. world_size, master_addr, master_port).
        return_values (dict): dictionary collecting the workers' return values,
            keyed by rank.
        model: CrypTen model to unpack received parameters into, if any.
    """
    payload = worker.send_msg(msg, worker).object
    return_values[rank] = utils.unpack_values(payload, model)
示例#6
0
def test_pack_crypten_model():
    """Unpacking a packed CrypTen model into its zeroed original restores the outputs."""
    sample = th.rand(1, 28 * 28)
    reference_model = crypten.nn.from_pytorch(ExampleNet(), sample)
    reference_out = reference_model(sample)

    packed = utils.pack_values(reference_model)

    # Zero every parameter in place so a matching output can only come from `packed`.
    with th.no_grad():
        for param in reference_model.parameters():
            assert isinstance(param, th.Tensor)
            param.set_(th.zeros_like(param))

    restored_model = utils.unpack_values(packed, model=reference_model)

    restored_out = restored_model(sample)
    assert th.all(reference_out == restored_out)
示例#7
0
def test_pack_other():
    """A plain (non-tensor) value survives a pack/unpack round trip unchanged."""
    packed = utils.pack_values(42)
    unpacked = utils.unpack_values(packed)
    assert 42 == unpacked
示例#8
0
        def wrapper(*args, **kwargs):
            """Run ``func`` as a CrypTen computation across a local party and ``workers``.

            The local process is party 0; the i-th entry of ``workers`` becomes
            party ``i + 1``. When ``func`` is a ``sy.Plan`` it is built locally and
            sent to every worker. Otherwise it is wrapped in a ``jail.JailRunner``,
            and its serialized form is shipped in a ``CryptenInitJail`` message.
            If ``model`` is set, it is exported to ONNX and converted to a CrypTen
            model; the local party's result is unpacked into that model.

            ``func``, ``workers``, ``model``, ``dummy_input``, ``master_addr`` and
            ``master_port`` are closure variables of the enclosing decorator (not
            visible here).

            NOTE(review): ``*args`` / ``**kwargs`` are accepted but not forwarded;
            every party is started with empty args/kwargs.

            Returns:
                dict: maps each rank in ``range(len(workers) + 1)`` to that
                party's unpacked return value.

            Raises:
                TypeError: if ``model`` is not a ``torch.nn.Module`` or
                    ``dummy_input`` is not a ``torch.Tensor``.
                ValueError: if ``model`` is set but ``dummy_input`` is None.
            """
            # TODO:
            # - check if workers are reachable / they can handle the computation
            # - check return code of processes for possible failure

            if model is not None:
                if not isinstance(model, th.nn.Module):
                    raise TypeError("model must be a torch.nn.Module")
                if dummy_input is None:
                    raise ValueError(
                        "must provide dummy_input when model is set")
                if not isinstance(dummy_input, th.Tensor):
                    raise TypeError("dummy_input must be a torch.Tensor")
                onnx_model = utils.pytorch_to_onnx(model, dummy_input)
            else:
                onnx_model = None

            crypten_model = (
                None if onnx_model is None else utils.onnx_to_crypten(onnx_model)
            )

            world_size = len(workers) + 1
            return_values = {rank: None for rank in range(world_size)}

            if isinstance(func, sy.Plan):
                using_plan = True
                plan = func

                # This is needed because at building we use a set of methods defined in syft (ex: load).
                # The hooks and the temporary crypten.init() are undone even if
                # plan.build() raises, so a failing plan does not leave syft
                # hooked or crypten initialized for later calls.
                hook_plan_building()
                try:
                    crypten.init()
                    try:
                        plan.build()
                    finally:
                        crypten.uninit()
                finally:
                    unhook_plan_building()

                # Mark the plan so the other workers will use that tag to retrieve the plan
                plan.tags = ["crypten_plan"]

                for worker in workers:
                    plan.send(worker)

                jail_or_plan = plan

            else:  # func
                using_plan = False
                jail_runner = jail.JailRunner(func=func, model=crypten_model)
                # The serialized runner is what gets sent to the remote workers.
                ser_jail_runner = jail.JailRunner.simplify(jail_runner)

                jail_or_plan = jail_runner

            # Remote workers take ranks 1..len(workers), in iteration order.
            rank_to_worker_id = {
                rank: worker.id for rank, worker in enumerate(workers, start=1)
            }

            sy.local_worker._set_rank_to_worker_id(rank_to_worker_id)

            # Start local party
            process, queue = _new_party(jail_or_plan, 0, world_size,
                                        master_addr, master_port, (), {})

            # If crypten is initialized in this process, uninit it while the
            # parties run and re-init it once they have finished (below).
            # NOTE(review): if anything between here and the re-init raises,
            # crypten is left uninitialized.
            was_initialized = DistributedCommunicator.is_initialized()
            if was_initialized:
                crypten.uninit()
            process.start()

            # Run TTP if required
            # TODO: run ttp in a specified worker
            if crypten.mpc.ttp_required():
                # The TTP server is given rank world_size, one past the last party.
                # NOTE(review): ttp_process is started but never joined here.
                ttp_process, _ = _new_party(
                    crypten.mpc.provider.TTPServer,
                    world_size,
                    world_size,
                    master_addr,
                    master_port,
                    (),
                    {},
                )
                ttp_process.start()

            # Send messages to other workers so they start their parties
            threads = []
            for rank, worker in enumerate(workers, start=1):
                if using_plan:
                    msg = CryptenInitPlan((rank_to_worker_id, world_size,
                                           master_addr, master_port))
                else:  # jail
                    msg = CryptenInitJail(
                        (rank_to_worker_id, world_size, master_addr,
                         master_port),
                        ser_jail_runner,
                        onnx_model,
                    )
                # NOTE(review): remote results are unpacked without a model
                # (no `model` argument is passed to _send_party_info).
                thread = threading.Thread(target=_send_party_info,
                                          args=(worker, rank, msg,
                                                return_values))
                thread.start()
                threads.append(thread)

            # Wait for local party and sender threads
            # Joining the process blocks! But queue.get() can also wait for the party
            # and it works fine.
            # process.join() -> blocks
            local_party_result = queue.get()
            return_values[0] = utils.unpack_values(local_party_result,
                                                   crypten_model)
            for thread in threads:
                thread.join()
            if was_initialized:
                crypten.init()

            return return_values