Example #1
    def __call__(self, protocol_function):
        # Create the roles declared in the decorator
        roles = {
            role_id:
            Role(worker=VirtualWorker(id=role_id, hook=sy.local_worker.hook))
            for role_id in self.role_names
        }
        for role_id, state_tensors in self.states.items():
            for tensor in state_tensors:
                roles[role_id].register_state_tensor(tensor)

        protocol = Protocol(
            name=protocol_function.__name__,
            forward_func=protocol_function,
            roles=roles,
            id=sy.ID_PROVIDER.pop(),
            owner=sy.local_worker,
        )

        try:
            protocol.build()
        except TypeError as e:
            raise ValueError(
                "Automatic build using @func2protocol failed!\nCheck that:\n"
                " - you have provided the correct number of shapes in args_shape\n"
                " - you have no simple numbers like int or float as args. If you do "
                "so, please consider using a tensor instead."
            ) from e

        return protocol
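
A hypothetical usage sketch for this decorator. The decorator name (`sy.func2protocol`, as referenced in the error message above) and the `args_shape` constructor parameter are assumptions; the role names, shapes, and function body are illustrative. The keys of the mapping define the role names the decorator iterates over.

import syft as sy

# Hypothetical usage: role names ("alice", "bob"), shapes, and the protocol
# body are all illustrative, not the library's confirmed API.
@sy.func2protocol(args_shape={"alice": ((1,),), "bob": ((1,),)})
def add_tensors(alice_tensor, bob_tensor):
    return alice_tensor + bob_tensor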
Example #2
    def __call__(self, protocol_function):
        # Create the roles declared in the decorator
        roles = {role_id: Role() for role_id in self.args_shape.keys()}

        protocol = Protocol(
            name=protocol_function.__name__,
            forward_func=protocol_function,
            roles=roles,
            id=sy.ID_PROVIDER.pop(),
            owner=sy.local_worker,
        )

        # Build the protocol automatically
        # TODO We can always build automatically, can't we? Except when workers don't
        # have the tensors in their store yet. Do we handle that?
        if self.args_shape:
            try:
                protocol.build()
            except TypeError as e:
                raise ValueError(
                    "Automatic build using @func2protocol failed!\nCheck that:\n"
                    " - you have provided the correct number of shapes in args_shape\n"
                    " - you have no simple numbers like int or float as args. If you do "
                    "so, please consider using a tensor instead."
                ) from e
        return protocol
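
When `args_shape` is empty, the protocol comes back unbuilt, which is the concern raised in the TODO above. A minimal sketch of deferring the build until the workers' stores hold the needed tensors (the call site is assumed):

# Sketch: build explicitly later, once input tensors exist in the stores.
# Note that a bare protocol.build() raises TypeError directly, without the
# friendly ValueError wrapper added inside the decorator.
protocol.build()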
Example #3
    def __init__(
        self,
        name: str = None,
        include_state: bool = False,
        is_built: bool = False,
        forward_func=None,
        state_tensors=None,
        role: Role = None,
        # General kwargs
        id: Union[str, int] = None,
        owner: "sy.workers.BaseWorker" = None,
        tags: List[str] = None,
        description: str = None,
    ):
        AbstractObject.__init__(self, id, owner, tags, description, child=None)

        # Plan instance info
        self.name = name or self.__class__.__name__

        self.role = role or Role(state_tensors=state_tensors, owner=owner)

        self.include_state = include_state
        self.is_built = is_built
        self.torchscript = None

        # The plan has not been sent so it has no reference to remote locations
        self.pointers = dict()

        if not hasattr(self, "forward"):
            self.forward = forward_func or None

        self.__name__ = self.__repr__()  # For PyTorch jit tracing compatibility
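
A minimal construction sketch for this initializer, reusing the `sy.ID_PROVIDER` and `sy.local_worker` helpers seen in the decorator examples above; the `forward` body is illustrative.

import syft as sy

def forward(x):
    return x + 1  # illustrative computation

plan = Plan(
    name="add_one",
    forward_func=forward,
    id=sy.ID_PROVIDER.pop(),
    owner=sy.local_worker,
)
assert plan.forward is forward  # set because the class defines no `forward`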
Example #4
    @staticmethod
    def _restore_placeholders(action: ComputationAction):
        """Converts PlaceholderIds back to PlaceHolders in an Action."""
        def wrap_in_placeholder(ph_id):
            return PlaceHolder(id=ph_id)

        action.target = Role.nested_object_traversal(action.target,
                                                     wrap_in_placeholder,
                                                     PlaceholderId)
        action.args = Role.nested_object_traversal(action.args,
                                                   wrap_in_placeholder,
                                                   PlaceholderId)
        action.kwargs = Role.nested_object_traversal(action.kwargs,
                                                     wrap_in_placeholder,
                                                     PlaceholderId)
        action.return_ids = Role.nested_object_traversal(
            action.return_ids, wrap_in_placeholder, PlaceholderId)
        return action
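
`Role.nested_object_traversal` walks a nested structure and applies the callback to every instance of the given type. A small sketch of the restoration above, assuming `PlaceholderId` wraps a raw id value:

# Sketch: every PlaceholderId in the nested structure becomes a PlaceHolder
# carrying that id; all other values pass through untouched.
args = (PlaceholderId(1), [PlaceholderId(2), 42])
restored = Role.nested_object_traversal(
    args, lambda ph_id: PlaceHolder(id=ph_id), PlaceholderId
)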
Example #5
def test_register_computation_action():
    role = Role()
    placeholder = PlaceHolder()
    target = torch.ones([1])

    action = ("__add__", target, (), {})

    role.register_action((action, placeholder), ComputationAction)

    assert len(role.actions) == 1

    registered = role.actions[0]

    assert isinstance(registered, ComputationAction)
    assert registered.name == "__add__"
    assert registered.target == target
    assert registered.args == ()
    assert registered.kwargs == {}
    assert registered.return_ids == (placeholder.id,)
Example #6
    def __init__(
        self,
        name: str = None,
        include_state: bool = False,
        is_built: bool = False,
        forward_func=None,
        state_tensors=(),  # immutable default avoids the mutable-default pitfall
        role: Role = None,
        # General kwargs
        id: Union[str, int] = None,
        owner: "sy.workers.BaseWorker" = None,
        tags: List[str] = None,
        input_types: list = None,
        description: str = None,
    ):
        AbstractObject.__init__(self, id, owner, tags, description, child=None)

        # Plan instance info
        self.name = name or self.__class__.__name__

        self.role = role or Role()

        if role is None:
            for st in state_tensors:
                self.role.register_state_tensor(st)

        self.include_state = include_state
        self.is_building = False
        self.state_attributes = {}
        self.is_built = is_built
        self.torchscript = None
        self.input_types = input_types
        self.tracing = False

        # The plan has not been sent so it has no reference to remote locations
        self.pointers = dict()

        if not hasattr(self, "forward"):
            self.forward = forward_func or None
        """
        When we use methods defined in a framework (like: torch.randn) we have a framework
        wrapper that helps as register and keep track of what methods are called
        With the below lines, we "register" what frameworks we have support to handle
        """
        self.wrapped_framework = {}
        for f_name, f_packages in framework_packages.items():
            self.wrapped_framework[f_name] = FrameworkWrapper(
                f_packages, self.role, self.owner)

        self.__name__ = self.__repr__()  # For PyTorch jit tracing compatibility

        # List of available translations
        self.translations = []
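
A short sketch of how the registry built above might be consulted during tracing; the `"torch"` key assumes `framework_packages` is keyed by package name.

# Hypothetical lookup: calls such as torch.randn are routed through the
# FrameworkWrapper registered for that framework while the plan is traced.
plan = Plan()
torch_wrapper = plan.wrapped_framework["torch"]  # key name is an assumption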
Example #7
    def translate_action(self, action: ComputationAction, to_framework: str,
                         role: Role):
        """Uses threepio to perform command level translation given a specific action"""
        self._restore_placeholders(action)
        threepio = Threepio(self.plan.base_framework, to_framework, None)
        function_name = action.name.split(".")[-1]
        args = action.args if action.target is None else (action.target,
                                                          *action.args)
        translated_cmds = threepio.translate(
            Command(function_name, args, action.kwargs))

        if len(translated_cmds) > 1:
            return self.translate_multi_action(translated_cmds, action, role)

        for cmd in translated_cmds:
            role_action = (
                (".".join(cmd.attrs), None, tuple(cmd.args), cmd.kwargs),
                action.return_ids,
            )
            role.register_action(role_action, ComputationAction)
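
A hypothetical call site for this method; `translator`, `action`, and `new_role` are assumed to exist, and the target-framework string is illustrative.

# Hypothetical: translate one traced action into the target framework and
# register the resulting command(s) on the destination role.
translator.translate_action(action, to_framework="tfjs", role=new_role)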
Example #8
    def translate_multi_action(self, translated_cmds: List[Command],
                               action: ComputationAction, role: Role):
        cmd_config = translated_cmds.pop(0)
        store = {}
        actions = []
        for cmd in translated_cmds:
            # Create local store of placeholders
            if cmd.placeholder_output is not None:
                store[cmd.placeholder_output] = PlaceHolder(role=role)

            for i, arg in enumerate(cmd.args):
                # Replace any Threepio placeholders with PySyft placeholders
                if isinstance(arg, pythreepio.command.Placeholder):
                    cmd.args[i] = store.get(arg.key, None)

            # Create the action in the format needed by the role's register_action method
            role_action = (
                (".".join(cmd.attrs), None, tuple(cmd.args), cmd.kwargs),
                store[cmd.placeholder_output],
            )
            role.register_action(role_action, ComputationAction)
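
The pattern here is output chaining: when Threepio expands a single PyTorch call into several target-framework commands, intermediate results travel as Threepio `Placeholder` objects. Each command that produces an intermediate registers a fresh PySyft `PlaceHolder` under its `placeholder_output` key, and later commands that consume it get their Threepio placeholders swapped for that same `PlaceHolder`, keeping the translated actions correctly wired together.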
Example #9
    def __init__(
        self,
        name: str = None,
        include_state: bool = False,
        is_built: bool = False,
        forward_func=None,
        state_tensors=(),  # immutable default avoids the mutable-default pitfall
        role: Role = None,
        # General kwargs
        id: Union[str, int] = None,
        owner: "sy.workers.BaseWorker" = None,
        tags: List[str] = None,
        input_types: list = None,
        description: str = None,
        roles: Dict[str, Role] = None,
        base_framework: str = TranslationTarget.PYTORCH.value,
    ):
        super().__init__(id, owner, tags, description, child=None)

        # Plan instance info
        self.name = name or self.__class__.__name__

        self.role = role or Role()

        if role is None:
            for st in state_tensors:
                self.role.register_state_tensor(st)

        self.include_state = include_state
        self.is_building = False
        self.state_attributes = {}
        self.is_built = is_built
        self.torchscript = None
        self.input_types = input_types
        self.validate_input_types = True
        self.tracing = False
        self._base_framework = base_framework
        self.roles = roles or {base_framework: self.role}

        # The plan has not been sent so it has no reference to remote locations
        self.pointers = {}

        if not hasattr(self, "forward"):
            self.forward = forward_func or None

        self.__name__ = self.__repr__()  # For PyTorch jit tracing compatibility

        # List of available translations
        self.translations = []
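
A small sketch of the invariant this constructor establishes: without an explicit `roles` mapping, the default role is registered under the base-framework key (default construction is assumed to be valid).

# Sketch: the default role lives under the base framework's key.
plan = Plan()
assert plan.roles[TranslationTarget.PYTORCH.value] is plan.role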
Example #10
def test_reset():
    role = Role()
    placeholder = PlaceHolder()
    target = torch.ones([1])

    action = ("get", target, (), {})

    role.register_action((action, placeholder), CommunicationAction)
    role.placeholders = {"ph_id1": PlaceHolder(), "ph_id2": PlaceHolder()}
    role.input_placeholder_ids = ("input1", "input2")
    role.output_placeholder_ids = ("output1",)

    assert len(role.actions) == 1
    assert len(role.placeholders) == 2
    assert role.input_placeholder_ids == ("input1", "input2")
    assert role.output_placeholder_ids == ("output1",)

    role.reset()

    assert len(role.actions) == 0
    assert len(role.placeholders) == 0
    assert role.input_placeholder_ids == ()
    assert role.output_placeholder_ids == ()
Example #11
    def get_role_for_owner(self, owner):
        """Return the Role for `owner`, creating one lazily on first access."""
        if owner.id not in self.roles:
            self.roles[owner.id] = Role()
        return self.roles[owner.id]
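
A usage sketch showing the lazy, memoized behavior of this helper; `plan` and `worker` are assumed to exist.

# Sketch: the first call creates a Role for the owner; later calls reuse it.
role = plan.get_role_for_owner(worker)
assert plan.get_role_for_owner(worker) is role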