Example #1
0
    def run_crypten_party_jail(self, msg: CryptenInitJail):  # pragma: no cover
        """Run crypten party according to the information received.

        Args:
            msg (CryptenInitJail): message whose ``crypten_context`` unpacks to
                (rank_to_worker_id, world_size, master_addr, master_port), whose
                ``jail_runner`` holds the serialized function to run and whose
                ``model`` is either None or an ONNX model to convert to crypten.

        Returns:
            An ObjectMessage containing the return value of the crypten function computed.

        Raises:
            ValueError: if the current rank cannot be determined from
                ``rank_to_worker_id``.
        """

        rank_to_worker_id, world_size, master_addr, master_port = msg.crypten_context

        cid = syft.ID_PROVIDER.pop()
        # Register the rank -> worker id translation for this computation. It is
        # removed in the ``finally`` below so a failure doesn't leak the entry.
        syft_crypten.RANK_TO_WORKER_ID[cid] = rank_to_worker_id
        try:
            ser_func = msg.jail_runner
            onnx_model = msg.model
            crypten_model = (
                None if onnx_model is None else utils.onnx_to_crypten(onnx_model)
            )
            jail_runner = JailRunner.detail(ser_func, model=crypten_model)

            rank = self._current_rank(rank_to_worker_id)
            # Explicit check rather than ``assert`` so it still runs under
            # ``python -O``. This matches run_crypten_party_plan.
            if rank is None:
                raise ValueError("Current rank can't be None")

            return_value = run_party(cid, jail_runner, rank, world_size,
                                     master_addr, master_port, (), {})
        finally:
            # remove rank to id translation dict (also on failure)
            syft_crypten.RANK_TO_WORKER_ID.pop(cid, None)

        return ObjectMessage(return_value)
Example #2
0
def test_send_msg():
    """Tests sending a message with a specific ID

    This is a simple test to ensure that the BaseWorker interface
    can properly send/receive a message containing a tensor, and that the
    worker's ``message_pending_time`` delays the send accordingly.
    """
    from time import perf_counter

    # get pointer to local worker
    me = sy.torch.hook.local_worker

    # pending time to simulate latency (optional). The local worker is shared
    # between tests, so remember the previous value and restore it even if
    # sending fails part-way.
    previous_pending_time = getattr(me, "message_pending_time", 0)
    me.message_pending_time = 0.1
    try:
        # create a new worker (to send the object to)
        worker_id = sy.ID_PROVIDER.pop()
        bob = VirtualWorker(sy.torch.hook, id=f"bob{worker_id}")

        # initialize the object and save its id
        obj = torch.Tensor([100, 100])
        obj_id = obj.id

        # Send data to bob, timed with a monotonic clock (immune to wall-clock
        # adjustments).
        start_time = perf_counter()
        me.send_msg(ObjectMessage(obj), bob)
        elapsed_time = perf_counter() - start_time
    finally:
        me.message_pending_time = previous_pending_time

    # ensure that object is now on bob's machine
    assert obj_id in bob.object_store._objects
    # ensure that the send took at least the configured 0.1 secs
    assert elapsed_time >= 0.1
Example #3
0
    def send_obj(self, obj: object, location: "BaseWorker"):
        """Wrap ``obj`` in an ObjectMessage and dispatch it to ``location``.

        Args:
            obj: A torch Tensor or Variable object to be sent.
            location: A BaseWorker instance indicating the worker which should
                receive the object.

        Returns:
            Whatever ``self.send_msg`` returns for the wrapped message.
        """
        message = ObjectMessage(obj)
        return self.send_msg(message, location)
Example #4
0
    def run_crypten_party_plan(
            self, msg: CryptenInitPlan) -> ObjectMessage:  # pragma: no cover
        """Run crypten party according to the information received.

        Args:
            msg (CryptenInitPlan): message whose ``crypten_context`` unpacks to
                (rank_to_worker_id, world_size, master_addr, master_port) and
                whose ``model`` is either None or an ONNX model to convert to
                crypten and pass to the plan.

        Returns:
            An ObjectMessage containing the return value of the crypten function computed.

        Raises:
            ValueError: if the worker's search for "crypten_plan" does not return
                exactly one plan, or if the current rank can't be determined.
        """

        rank_to_worker_id, world_size, master_addr, master_port = msg.crypten_context

        cid = syft.ID_PROVIDER.pop()
        # Register the rank -> worker id translation for this computation. It is
        # removed in the outer ``finally`` so no error path leaks the entry.
        syft_crypten.RANK_TO_WORKER_ID[cid] = rank_to_worker_id
        try:
            onnx_model = msg.model
            crypten_model = (
                None if onnx_model is None else utils.onnx_to_crypten(onnx_model)
            )

            # TODO Change this, we need a way to handle multiple plan definitions
            plans = self.worker.search("crypten_plan")
            if len(plans) != 1:
                raise ValueError(
                    f"Error: {len(plans)} plans found. There should be only 1.")

            plan = plans[0].get()
            try:
                rank = self._current_rank(rank_to_worker_id)
                if rank is None:
                    raise ValueError("Current rank can't be None")

                # Compare against None explicitly (as done for onnx_model above).
                # A model object's truthiness need not reflect whether one was
                # supplied.
                args = () if crypten_model is None else (crypten_model, )

                return_value = run_party(cid, plan, rank, world_size, master_addr,
                                         master_port, args, {})
            finally:
                # Delete the plan at the end of the computation, also on failure.
                # Presumably this keeps a failed run from leaving it registered
                # for the next ``search``.
                self.worker.de_register_obj(plan)
        finally:
            # remove rank to id translation dict
            syft_crypten.RANK_TO_WORKER_ID.pop(cid, None)

        return ObjectMessage(return_value)
Example #5
0
def test_recv_msg():
    """Tests the recv_msg command with 2 tests.

    The first test uses recv_msg to send an object to alice.

    The second test uses recv_msg to request the object
    previously sent to alice.
    """

    # TEST 1: send tensor to alice

    # create a worker to send data to
    worker_id = sy.ID_PROVIDER.pop()
    alice = VirtualWorker(sy.torch.hook, id=f"alice{worker_id}")

    # create object to send
    obj = torch.Tensor([100, 100])

    # create/serialize message
    message = ObjectMessage(obj)
    bin_msg = serde.serialize(message)

    # have alice receive message
    alice.recv_msg(bin_msg)

    # ensure that object is now in alice's registry
    assert obj.id in alice.object_store._objects

    # TEST 2: get tensor back from alice

    # Create message: Get tensor from alice
    # NOTE(review): the meaning of the ``None`` and ``""`` arguments is not
    # visible here; confirm against ObjectRequestMessage's definition.
    message = ObjectRequestMessage(obj.id, None, "")

    # serialize message
    bin_msg = serde.serialize(message)

    # call receive message on alice
    resp = alice.recv_msg(bin_msg)

    # assert that response is correct type *before* deserializing it, so a
    # wrong type fails with a clear assertion instead of inside deserialize
    assert isinstance(resp, bytes)

    obj_2 = sy.serde.deserialize(resp)

    # ensure that the object we receive is correct
    assert obj_2.id == obj.id
Example #6
0
def test_send_msg():
    """Check that BaseWorker.send_msg delivers a tensor-carrying message.

    A tensor is wrapped in an ObjectMessage and sent from the local worker to a
    freshly created VirtualWorker. It must then be present in that worker's
    ``_objects`` registry.
    """
    sender = sy.torch.hook.local_worker

    # a uniquely named destination worker
    receiver = VirtualWorker(sy.torch.hook, id=f"bob{sy.ID_PROVIDER.pop()}")

    # remember the tensor's id before it is sent
    tensor = torch.Tensor([100, 100])
    tensor_id = tensor.id

    sender.send_msg(ObjectMessage(tensor), receiver)

    # the tensor should now be registered on the receiving worker
    assert tensor_id in receiver._objects