Example #1
    def run_crypten_party_jail(self, msg: CryptenInitJail):  # pragma: no cover
        """Run crypten party according to the information received.

        Args:
            msg (CryptenInitJail): should contain the rank_to_worker_id, world_size,
                                master_addr and master_port.

        Returns:
            An ObjectMessage containing the return value of the crypten function computed.
        """

        rank_to_worker_id, world_size, master_addr, master_port = msg.crypten_context

        cid = syft.ID_PROVIDER.pop()
        syft_crypten.RANK_TO_WORKER_ID[cid] = rank_to_worker_id

        ser_func = msg.jail_runner
        onnx_model = msg.model
        crypten_model = None if onnx_model is None else utils.onnx_to_crypten(
            onnx_model)
        jail_runner = JailRunner.detail(ser_func, model=crypten_model)

        rank = self._current_rank(rank_to_worker_id)
        assert rank is not None

        return_value = run_party(cid, jail_runner, rank, world_size,
                                 master_addr, master_port, (), {})
        # remove the rank-to-worker-id translation dict
        del syft_crypten.RANK_TO_WORKER_ID[cid]

        return ObjectMessage(return_value)
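
The handler above relies on a _current_rank helper to find this worker's rank in the rank-to-worker-id mapping. A minimal sketch of such a helper, assuming rank_to_worker_id maps rank -> worker id and self.worker.id identifies the local worker (hypothetical, not taken from the source):

    def _current_rank(self, rank_to_worker_id):
        # Return the rank assigned to this worker, or None if it is not a party
        for rank, worker_id in rank_to_worker_id.items():
            if worker_id == self.worker.id:
                return rank
        return None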
Example #2
def test_serialize_models():
    class ExampleNet(th.nn.Module):
        def __init__(self):
            super(ExampleNet, self).__init__()
            self.fc1 = th.nn.Linear(1024, 100)
            self.fc2 = th.nn.Linear(
                100, 2
            )  # For binary classification, final layer needs only 2 outputs

        def forward(self, x):
            out = self.fc1(x)
            out = th.nn.functional.relu(out)
            out = self.fc2(out)
            return out

    dummy_input = th.ones(1, 1024)
    example_net = ExampleNet()

    expected_output = example_net(dummy_input)

    onnx_bytes = utils.pytorch_to_onnx(example_net, dummy_input)
    crypten_model = utils.onnx_to_crypten(onnx_bytes)
    output = crypten_model(dummy_input)

    assert th.allclose(expected_output, output)
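
The test above only checks the plaintext round trip; as a hedged follow-up, the converted model would typically be encrypted for secure inference with the CrypTen API (a sketch, exact calls may differ between CrypTen versions):

import crypten

crypten.init()
crypten_model.encrypt()                      # secret-share the model parameters
enc_input = crypten.cryptensor(dummy_input)  # secret-share the input
enc_output = crypten_model(enc_input)
assert th.allclose(expected_output, enc_output.get_plain_text(), atol=1e-3)
crypten.uninit()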
Example #3
    def run_crypten_party_plan(
            self, msg: CryptenInitPlan) -> ObjectMessage:  # pragma: no cover
        """Run crypten party according to the information received.

        Args:
            msg (CryptenInitPlan): should contain the rank_to_worker_id, world_size,
                                master_addr and master_port.

        Returns:
            An ObjectMessage containing the return value of the crypten function computed.
        """

        rank_to_worker_id, world_size, master_addr, master_port = msg.crypten_context

        cid = syft.ID_PROVIDER.pop()
        syft_crypten.RANK_TO_WORKER_ID[cid] = rank_to_worker_id

        onnx_model = msg.model
        crypten_model = None if onnx_model is None else utils.onnx_to_crypten(
            onnx_model)

        # TODO Change this, we need a way to handle multiple plan definitions
        plans = self.worker.search("crypten_plan")
        if len(plans) != 1:
            raise ValueError(
                f"Error: {len(plans)} plans found. There should be only 1.")

        plan = plans[0].get()

        rank = self._current_rank(rank_to_worker_id)
        if rank is None:
            raise ValueError("Current rank can't be None")

        if crypten_model:
            args = (crypten_model, )
        else:
            args = ()

        return_value = run_party(cid, plan, rank, world_size, master_addr,
                                 master_port, args, {})
        # remove the rank-to-worker-id translation dict
        del syft_crypten.RANK_TO_WORKER_ID[cid]

        # Delete the plan at the end of the computation
        self.worker.de_register_obj(plan)

        return ObjectMessage(return_value)
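
The plan is looked up here by its "crypten_plan" tag; it is published by the coordinating worker, which tags and distributes it as shown in full in Example #4:

# Sending side (see Example #4): tag the built plan and ship it to each worker
plan.tags = ["crypten_plan"]
for worker in workers:
    plan.send(worker)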
Example #4
        def wrapper(*args, **kwargs):
            # TODO:
            # - check if workers are reachable / they can handle the computation
            # - check return code of processes for possible failure

            if len(workers) != len(set(worker.id
                                       for worker in workers)):  # noqa: C401
                raise RuntimeError(
                    "found workers with same ID but IDs must be unique")

            if model is not None:
                if not isinstance(model, th.nn.Module):
                    raise TypeError("model must be a torch.nn.Module")
                if dummy_input is None:
                    raise ValueError(
                        "must provide dummy_input when model is set")
                if not isinstance(dummy_input, th.Tensor):
                    raise TypeError("dummy_input must be a torch.Tensor")
                onnx_model = utils.pytorch_to_onnx(model, dummy_input)
            else:
                onnx_model = None

            crypten_model = None if onnx_model is None else utils.onnx_to_crypten(
                onnx_model)

            world_size = len(workers)
            manager = multiprocessing.Manager()
            return_values = manager.dict(
                {rank: None
                 for rank in range(world_size)})

            rank_to_worker_id = dict(
                zip(range(0, len(workers)), [worker.id for worker in workers]))

            # TODO: run ttp in a specified worker
            # if crypten.mpc.ttp_required():
            #     ttp_process, _ = _new_party(
            #         crypten.mpc.provider.TTPServer,
            #         world_size,
            #         world_size,
            #         master_addr,
            #         master_port,
            #         (),
            #         {},
            #     )
            #     ttp_process.start()

            if isinstance(func, sy.Plan):
                plan = func

                # This is needed because building the plan uses a set of
                # methods defined in syft (e.g. load)
                hook_plan_building()
                was_initialized = DistributedCommunicator.is_initialized()
                if not was_initialized:
                    crypten.init()

                # We build the plan with the crypten model (when one is given) so that
                # the actions traced inside the plan know about its existence
                if crypten_model is None:
                    plan.build()
                else:
                    plan.build(crypten_model)

                if not was_initialized:
                    crypten.uninit()
                unhook_plan_building()

                # Mark the plan so the other workers will use that tag to retrieve the plan
                plan.tags = ["crypten_plan"]

                for worker in workers:
                    plan.send(worker)

                msg = CryptenInitPlan(
                    (rank_to_worker_id, world_size, master_addr, master_port),
                    onnx_model)

            else:  # func
                jail_runner = jail.JailRunner(func=func)
                ser_jail_runner = jail.JailRunner.simplify(jail_runner)

                msg = CryptenInitJail(
                    (rank_to_worker_id, world_size, master_addr, master_port),
                    ser_jail_runner,
                    onnx_model,
                )

            # Send messages to the other workers so they start their parties
            processes = []
            for i in range(len(workers)):
                rank = i
                process = multiprocessing.Process(
                    target=_send_party_info,
                    args=(workers[i], rank, msg, return_values, crypten_model),
                )
                process.start()
                processes.append(process)

            # Wait for the workers running the parties to return a response
            for process in processes:
                process.join()

            return return_values
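
This wrapper is returned by a decorator that receives the participating workers, the rendezvous address/port and, optionally, a torch model plus a dummy input. A hedged usage sketch, assuming a decorator along the lines of PySyft's run_multiworkers (the import path, the workers alice and bob, and the exact signature are assumptions, not taken from the source):

import torch as th
from syft.frameworks.crypten.context import run_multiworkers  # assumed import path

# alice and bob are placeholder remote workers created elsewhere
@run_multiworkers([alice, bob], master_addr="127.0.0.1", master_port=8080,
                  model=example_net, dummy_input=th.ones(1, 1024))
def secure_eval(crypten_model):
    # the body runs inside each party; the converted CrypTen model is passed in
    crypten_model.encrypt()
    ...

return_values = secure_eval()  # dict mapping each rank to that party's return value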
Example #5
        def wrapper(*args, **kwargs):
            # TODO:
            # - check if workers are reachable / they can handle the computation
            # - check return code of processes for possible failure

            if model is not None:
                if not isinstance(model, th.nn.Module):
                    raise TypeError("model must be a torch.nn.Module")
                if dummy_input is None:
                    raise ValueError(
                        "must provide dummy_input when model is set")
                if not isinstance(dummy_input, th.Tensor):
                    raise TypeError("dummy_input must be a torch.Tensor")
                onnx_model = utils.pytorch_to_onnx(model, dummy_input)
            else:
                onnx_model = None

            crypten_model = None if onnx_model is None else utils.onnx_to_crypten(
                onnx_model)

            world_size = len(workers) + 1
            return_values = {rank: None for rank in range(world_size)}

            if isinstance(func, sy.Plan):
                using_plan = True
                plan = func

                # This is needed because building the plan uses a set of methods defined in syft (e.g. load)
                hook_plan_building()
                crypten.init()
                plan.build()
                crypten.uninit()
                unhook_plan_building()

                # Mark the plan so the other workers will use that tag to retrieve the plan
                plan.tags = ["crypten_plan"]

                for worker in workers:
                    plan.send(worker)

                jail_or_plan = plan

            else:  # func
                using_plan = False
                jail_runner = jail.JailRunner(func=func, model=crypten_model)
                ser_jail_runner = jail.JailRunner.simplify(jail_runner)

                jail_or_plan = jail_runner

            rank_to_worker_id = dict(
                zip(range(1,
                          len(workers) + 1),
                    [worker.id for worker in workers]))

            sy.local_worker._set_rank_to_worker_id(rank_to_worker_id)

            # Start local party
            process, queue = _new_party(jail_or_plan, 0, world_size,
                                        master_addr, master_port, (), {})

            was_initialized = DistributedCommunicator.is_initialized()
            if was_initialized:
                crypten.uninit()
            process.start()

            # Run TTP if required
            # TODO: run ttp in a specified worker
            if crypten.mpc.ttp_required():
                ttp_process, _ = _new_party(
                    crypten.mpc.provider.TTPServer,
                    world_size,
                    world_size,
                    master_addr,
                    master_port,
                    (),
                    {},
                )
                ttp_process.start()

            # Send messages to other workers so they start their parties
            threads = []
            for i in range(len(workers)):
                rank = i + 1
                if using_plan:
                    msg = CryptenInitPlan((rank_to_worker_id, world_size,
                                           master_addr, master_port))
                else:  # jail
                    msg = CryptenInitJail(
                        (rank_to_worker_id, world_size, master_addr,
                         master_port),
                        ser_jail_runner,
                        onnx_model,
                    )
                thread = threading.Thread(target=_send_party_info,
                                          args=(workers[i], rank, msg,
                                                return_values))
                thread.start()
                threads.append(thread)

            # Wait for the local party and the sender threads.
            # Calling process.join() here would block, so we wait on queue.get()
            # instead, which also blocks until the local party has finished.
            local_party_result = queue.get()
            return_values[0] = utils.unpack_values(local_party_result,
                                                   crypten_model)
            for thread in threads:
                thread.join()
            if was_initialized:
                crypten.init()

            return return_values
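
Unlike Example #4, this variant keeps one party local at rank 0 (plus the TTP process when required) and assigns ranks 1..len(workers) to the remote workers. A brief consumption sketch, reusing the hypothetical secure_eval from the sketch after Example #4:

results = secure_eval()       # {rank: return value of that party}
local_result = results[0]     # rank 0 is the local party in this variant
remote_results = {rank: results[rank] for rank in range(1, len(results))}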
Example #6
    def to_crypten(self):
        return utils.onnx_to_crypten(self.serialized_model)
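
This method comes from a larger class; a hypothetical minimal wrapper showing where it could live, assuming serialized_model holds ONNX bytes such as those produced by utils.pytorch_to_onnx (class name and constructor are illustrative, not from the source):

class SerializedOnnxModel:
    def __init__(self, serialized_model: bytes):
        # ONNX bytes, e.g. returned by utils.pytorch_to_onnx(model, dummy_input)
        self.serialized_model = serialized_model

    def to_crypten(self):
        # Convert the stored ONNX graph into a CrypTen model
        return utils.onnx_to_crypten(self.serialized_model)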