def test_plan_execution(client: sy.VirtualMachineClient) -> None:
    """Run a two-step Plan remotely and check both chained results.

    The first action adds two operand tensors into a first result slot; the
    second adds a third operand to that intermediate into a second slot.
    """
    # remote operands, sent in a fixed order
    lhs, rhs, extra = (
        th.tensor(values).send(client)
        for values in ([1, 2, 3], [4, 5, 6], [7, 8, 9])
    )

    # zero-filled remote slots; their ids are used as the actions' targets
    first_out, second_out = [th.tensor([0, 0, 0]).send(client) for _ in range(2)]
    first_uid = first_out.id_at_location
    second_uid = second_out.id_at_location

    def add_into(self_ptr, other_ptr, target_uid):
        # Build a ``torch.Tensor.add`` action whose output is stored at target_uid.
        return RunClassMethodAction(
            path="torch.Tensor.add",
            _self=self_ptr,
            args=[other_ptr],
            kwargs={},
            id_at_location=target_uid,
            address=Address(),
            msg_id=UID(),
        )

    first_action = add_into(lhs, rhs, first_uid)
    # chained: reads the slot written by the first action
    second_action = add_into(first_out, extra, second_uid)

    plan_pointer = Plan([first_action, second_action]).send(client)
    plan_pointer()

    expectations = (
        (th.tensor([5, 7, 9]), first_out),
        (th.tensor([12, 15, 18]), second_out),
    )
    for expected, pointer in expectations:
        assert all(expected == pointer.get())
def test_plan_serialization(client: sy.VirtualMachineClient) -> None:
    """Round-trip a Plan holding one instance of each Action type through serde.

    Verifies that deserialization yields a ``Plan`` whose actions match the
    original plan's actions in count, concrete type, and order.
    """
    # Several actions take a pointer as input; sending a real tensor is the
    # simplest way to obtain one. ``t`` itself is reused as payload for ``a8``.
    t = th.tensor([1, 2, 3])
    tensor_pointer = t.send(client)

    # define actions (this test only serializes them, it never executes them)
    a1 = GetObjectAction(
        id_at_location=UID(), address=Address(), reply_to=Address(), msg_id=UID()
    )
    a2 = RunFunctionOrConstructorAction(
        path="torch.Tensor.add",
        args=tuple(),
        kwargs={},
        id_at_location=UID(),
        address=Address(),
        msg_id=UID(),
    )
    a3 = RunClassMethodAction(
        path="torch.Tensor.add",
        _self=tensor_pointer,
        args=[],
        kwargs={},
        id_at_location=UID(),
        address=Address(),
        msg_id=UID(),
    )
    a4 = GarbageCollectObjectAction(id_at_location=UID(), address=Address())
    a5 = EnumAttributeAction(path="", id_at_location=UID(), address=Address())
    a6 = GetOrSetPropertyAction(
        path="",
        _self=tensor_pointer,
        id_at_location=UID(),
        address=Address(),
        args=[],
        kwargs={},
        action=PropertyActions.GET,
    )
    a7 = GetSetStaticAttributeAction(
        path="",
        id_at_location=UID(),
        address=Address(),
        action=StaticAttributeAction.GET,
    )
    a8 = SaveObjectAction(obj=StorableObject(id=UID(), data=t), address=Address())

    # define plan
    plan = Plan([a1, a2, a3, a4, a5, a6, a7, a8])

    # serialize / deserialize
    blob = serialize(plan)
    plan_reconstructed = sy.deserialize(blob=blob)

    # test
    assert isinstance(plan_reconstructed, Plan)
    # all() is vacuously True on an empty list, so pin the count as well --
    # otherwise a serde that silently dropped every action would still pass.
    assert len(plan_reconstructed.actions) == len(plan.actions)
    assert all(isinstance(a, Action) for a in plan_reconstructed.actions)
    # each action must come back as the same concrete class, in the same order
    assert [type(a) for a in plan_reconstructed.actions] == [
        type(a) for a in plan.actions
    ]
def test_plan_batched_execution(client: sy.VirtualMachineClient) -> None:
    """Run one Plan repeatedly, binding different ``x``/``y`` inputs each call.

    The plan computes ``torch.eq(x * w + b, y)`` remotely, where ``w`` and
    ``b`` are fixed model tensors. Each batch's ``y`` is the exact model
    output for its ``x``, so every call should yield ``[True, True]``.
    """
    # Model parameters, defined once so the remote model tensors and the
    # locally computed expected outputs cannot drift apart.
    weight_values = [1, 2]
    bias_values = [3, 4]

    # placeholders for our input
    input_tensor_pointer1 = th.tensor([0, 0]).send(client)
    input_tensor_pointer2 = th.tensor([0, 0]).send(client)

    # tensors in our model
    model_tensor_pointer1 = th.tensor(weight_values).send(client)
    model_tensor_pointer2 = th.tensor(bias_values).send(client)

    # placeholders for intermediate results
    result_tensor_pointer1 = th.tensor([0, 0]).send(client)
    result_tensor_pointer2 = th.tensor([0, 0]).send(client)
    result_tensor_pointer3 = th.tensor([0, 0]).send(client)

    # define plan: result1 = x * w; result2 = result1 + b; result3 = eq(result2, y)
    a1 = RunClassMethodAction(
        path="torch.Tensor.mul",
        _self=input_tensor_pointer1,
        args=[model_tensor_pointer1],
        kwargs={},
        id_at_location=result_tensor_pointer1.id_at_location,
        address=Address(),
        msg_id=UID(),
    )
    a2 = RunClassMethodAction(
        path="torch.Tensor.add",
        _self=result_tensor_pointer1,
        args=[model_tensor_pointer2],
        kwargs={},
        id_at_location=result_tensor_pointer2.id_at_location,
        address=Address(),
        msg_id=UID(),
    )
    a3 = RunFunctionOrConstructorAction(
        path="torch.eq",
        args=[result_tensor_pointer2, input_tensor_pointer2],
        kwargs={},
        id_at_location=result_tensor_pointer3.id_at_location,
        address=Address(),
        msg_id=UID(),
    )

    plan = Plan(
        [a1, a2, a3],
        inputs={"x": input_tensor_pointer1, "y": input_tensor_pointer2},
    )
    plan_pointer = plan.send(client)

    # Test
    # x is arbitrary input, y is the expected model(x) for that input
    x_locals = [th.tensor([1, 1]) + i for i in range(2)]
    y_locals = [
        x_local * th.tensor(weight_values) + th.tensor(bias_values)
        for x_local in x_locals
    ]
    # captured before sending so the comparison uses plain Python lists
    y_expected_values = [y_local.tolist() for y_local in y_locals]

    x_batches = [x_local.send(client) for x_local in x_locals]
    y_batches = [y_local.send(client) for y_local in y_locals]

    for x, y, y_expected in zip(x_batches, y_batches, y_expected_values):
        plan_pointer(x=x, y=y)
        # result3 already holds [True, True] after the first batch, so it alone
        # cannot detect a later call that silently did nothing. The intermediate
        # model output differs per batch, which makes a stale result visible.
        assert result_tensor_pointer2.get(delete_obj=False).tolist() == y_expected
        # checks if (model(x) == y) == [True, True]
        assert all(result_tensor_pointer3.get(delete_obj=False))