def __call__(self, protocol_function):
    """Wrap ``protocol_function`` into a built Protocol.

    One Role is created per id in ``self.role_names``, each backed by a fresh
    VirtualWorker on the local hook. Every tensor listed in ``self.states[role_id]``
    is registered as a state tensor of that role. The Protocol is then built
    immediately.

    Args:
        protocol_function: the decorated function; its ``__name__`` becomes the
            protocol name and it is used as the protocol's ``forward_func``.

    Returns:
        The built Protocol.

    Raises:
        ValueError: if ``protocol.build()`` raises TypeError. The original
            TypeError is chained as ``__cause__``.
    """
    # create the roles present in decorator
    roles = {
        role_id: Role(worker=VirtualWorker(id=role_id, hook=sy.local_worker.hook))
        for role_id in self.role_names
    }
    for role_id, state_tensors in self.states.items():
        for tensor in state_tensors:
            roles[role_id].register_state_tensor(tensor)

    protocol = Protocol(
        name=protocol_function.__name__,
        forward_func=protocol_function,
        roles=roles,
        id=sy.ID_PROVIDER.pop(),
        owner=sy.local_worker,
    )

    try:
        protocol.build()
    except TypeError as e:
        # Chain the cause so the underlying TypeError stays visible in the traceback.
        raise ValueError(
            "Automatic build using @func2protocol failed!\nCheck that:\n"
            " - you have provided the correct number of shapes in args_shape\n"
            " - you have no simple numbers like int or float as args. If you do "
            "so, please consider using a tensor instead."
        ) from e
    return protocol
def __call__(self, protocol_function):
    """Wrap ``protocol_function`` into a Protocol, building it when shapes are known.

    One default Role is created per key of ``self.args_shape``. If
    ``self.args_shape`` is falsy (empty or None), no roles are created and the
    protocol is returned unbuilt.

    Args:
        protocol_function: the decorated function; its ``__name__`` becomes the
            protocol name and it is used as the protocol's ``forward_func``.

    Returns:
        The Protocol, built if ``self.args_shape`` is truthy.

    Raises:
        ValueError: if ``protocol.build()`` raises TypeError. The original
            TypeError is chained as ``__cause__``.
    """
    # create the roles present in decorator; tolerate a None args_shape, which the
    # build guard below already treats as "nothing to build".
    roles = {role_id: Role() for role_id in (self.args_shape or {})}

    protocol = Protocol(
        name=protocol_function.__name__,
        forward_func=protocol_function,
        roles=roles,
        id=sy.ID_PROVIDER.pop(),
        owner=sy.local_worker,
    )

    # Build the protocol automatically
    # TODO We can always build automatically, can't we? Except if workers doesn't have
    # tensors yet in store. Do we handle that?
    if self.args_shape:
        try:
            protocol.build()
        except TypeError as e:
            raise ValueError(
                "Automatic build using @func2protocol failed!\nCheck that:\n"
                " - you have provided the correct number of shapes in args_shape\n"
                " - you have no simple numbers like int or float as args. If you do "
                "so, please consider using a tensor instead."
            ) from e

    return protocol
def __init__(
    self,
    name: str = None,
    include_state: bool = False,
    is_built: bool = False,
    forward_func=None,
    state_tensors=None,
    role: Role = None,
    # General kwargs
    id: Union[str, int] = None,
    owner: "sy.workers.BaseWorker" = None,
    tags: List[str] = None,
    description: str = None,
):
    """Set up the plan instance.

    Args:
        name: display name; the class name is used when this is falsy.
        include_state: stored on ``self.include_state``.
        is_built: stored on ``self.is_built``.
        forward_func: becomes ``self.forward`` unless the class already has ``forward``.
        state_tensors: passed to the new ``Role`` when no ``role`` is supplied.
        role: Role to reuse; a new ``Role(state_tensors=..., owner=...)`` is made if falsy.
        id, owner, tags, description: forwarded to ``AbstractObject.__init__``.
    """
    AbstractObject.__init__(self, id, owner, tags, description, child=None)

    # Plan instance info
    self.name = name if name else self.__class__.__name__
    self.role = role if role else Role(state_tensors=state_tensors, owner=owner)
    self.include_state = include_state
    self.is_built = is_built
    self.torchscript = None

    # Nothing has been sent yet, so there are no remote locations to track.
    self.pointers = {}

    # A subclass may already define `forward`; only fall back to forward_func otherwise.
    if not hasattr(self, "forward"):
        self.forward = forward_func if forward_func else None

    # For PyTorch jit tracing compatibility
    self.__name__ = self.__repr__()
def _restore_placeholders(action: ComputationAction):
    """Replace every PlaceholderId inside ``action`` with a PlaceHolder carrying that id.

    The target, args, kwargs and return_ids fields are each walked with
    ``Role.nested_object_traversal`` and rewritten in place. The same action
    object is returned.
    """

    def to_placeholder(placeholder_id):
        return PlaceHolder(id=placeholder_id)

    # Same field order as before: target, args, kwargs, return_ids.
    for field_name in ("target", "args", "kwargs", "return_ids"):
        current = getattr(action, field_name)
        setattr(
            action,
            field_name,
            Role.nested_object_traversal(current, to_placeholder, PlaceholderId),
        )
    return action
def test_register_computation_action():
    """A (action, placeholder) pair registers exactly one ComputationAction on the role."""
    role = Role()
    output_ph = PlaceHolder()
    operand = torch.ones([1])
    raw_action = ("__add__", operand, (), {})

    role.register_action((raw_action, output_ph), ComputationAction)

    assert len(role.actions) == 1
    recorded = role.actions[0]
    assert isinstance(recorded, ComputationAction)

    # Field-by-field comparison against what was registered.
    assert recorded.name == "__add__"
    assert recorded.target == operand
    assert recorded.args == ()
    assert recorded.kwargs == {}
    assert recorded.return_ids == (output_ph.id,)
def __init__(
    self,
    name: str = None,
    include_state: bool = False,
    is_built: bool = False,
    forward_func=None,
    state_tensors=None,
    role: Role = None,
    # General kwargs
    id: Union[str, int] = None,
    owner: "sy.workers.BaseWorker" = None,
    tags: List[str] = None,
    input_types: list = None,
    description: str = None,
):
    """Set up the plan instance and its framework wrappers.

    Args:
        name: display name; the class name is used when this is falsy.
        include_state: stored on ``self.include_state``.
        is_built: stored on ``self.is_built``.
        forward_func: becomes ``self.forward`` unless the class already has ``forward``.
        state_tensors: tensors registered as state on the newly created Role
            when ``role`` is None; ignored otherwise. Defaults to none.
        role: Role to reuse; a new ``Role()`` is created if falsy.
        id, owner, tags, description: forwarded to ``AbstractObject.__init__``.
        input_types: stored on ``self.input_types``.
    """
    AbstractObject.__init__(self, id, owner, tags, description, child=None)

    # None instead of a shared mutable `[]` default.
    if state_tensors is None:
        state_tensors = []

    # Plan instance info
    self.name = name or self.__class__.__name__
    self.role = role or Role()
    if role is None:
        for st in state_tensors:
            self.role.register_state_tensor(st)

    self.include_state = include_state
    self.is_building = False
    self.state_attributes = {}
    self.is_built = is_built
    self.torchscript = None
    self.input_types = input_types
    self.tracing = False

    # The plan has not been sent so it has no reference to remote locations
    self.pointers = dict()

    if not hasattr(self, "forward"):
        self.forward = forward_func or None

    # When we use methods defined in a framework (like torch.randn), a framework
    # wrapper helps us register and keep track of which methods are called.
    # Here we "register" one wrapper per supported framework, keyed by its name.
    self.wrapped_framework = {}
    for f_name, f_packages in framework_packages.items():
        self.wrapped_framework[f_name] = FrameworkWrapper(f_packages, self.role, self.owner)

    self.__name__ = self.__repr__()  # For PyTorch jit tracing compatibility

    # List of available translations
    self.translations = []
def translate_action(self, action: ComputationAction, to_framework: str, role: Role):
    """Translate one ComputationAction to ``to_framework`` via threepio and register it on ``role``.

    PlaceholderIds in ``action`` are first restored to PlaceHolders. If threepio
    yields more than one command, the work is delegated to
    ``translate_multi_action`` and its result is returned. Otherwise each
    resulting command (at most one) is registered with the original action's
    ``return_ids``.
    """
    self._restore_placeholders(action)
    translator = Threepio(self.plan.base_framework, to_framework, None)

    # Only the last dotted component is given to threepio as the function name.
    short_name = action.name.rpartition(".")[2]
    if action.target is None:
        call_args = action.args
    else:
        call_args = (action.target, *action.args)

    commands = translator.translate(Command(short_name, call_args, action.kwargs))

    if len(commands) > 1:
        return self.translate_multi_action(commands, action, role)

    for command in commands:
        payload = (".".join(command.attrs), None, tuple(command.args), command.kwargs)
        role.register_action((payload, action.return_ids), ComputationAction)
def translate_multi_action(self, translated_cmds: List[Command], action: ComputationAction, role: Role):
    """Register a multi-command threepio translation on ``role``.

    The first entry of ``translated_cmds`` is skipped, and the caller's list is
    no longer mutated. For each remaining command:
      * a new ``PlaceHolder(role=role)`` is stored under its ``placeholder_output`` key;
      * each threepio Placeholder argument is replaced, in place in ``cmd.args``,
        by the stored PlaceHolder with the same key (None if not created yet);
      * the command is registered as a ComputationAction whose return is the
        stored PlaceHolder for its ``placeholder_output``.

    Raises:
        KeyError: if a command's ``placeholder_output`` is None, because no
            PlaceHolder is stored for it.
    """
    # NOTE(review): `action` is not used here, so the final command's output is
    # not tied to action.return_ids. Confirm this is intended.
    store = {}
    for cmd in translated_cmds[1:]:
        # Create local store of placeholders
        if cmd.placeholder_output is not None:
            store[cmd.placeholder_output] = PlaceHolder(role=role)

        for i, arg in enumerate(cmd.args):
            if isinstance(arg, pythreepio.command.Placeholder):
                # Replace any threepio placeholders w/ pysyft placeholders
                cmd.args[i] = store.get(arg.key, None)

        # Create the action information needed by role.register_action
        role_action = (
            (".".join(cmd.attrs), None, tuple(cmd.args), cmd.kwargs),
            store[cmd.placeholder_output],
        )
        role.register_action(role_action, ComputationAction)
def __init__(
    self,
    name: str = None,
    include_state: bool = False,
    is_built: bool = False,
    forward_func=None,
    state_tensors=None,
    role: Role = None,
    # General kwargs
    id: Union[str, int] = None,
    owner: "sy.workers.BaseWorker" = None,
    tags: List[str] = None,
    input_types: list = None,
    description: str = None,
    roles: Dict[str, Role] = None,
    base_framework: str = TranslationTarget.PYTORCH.value,
):
    """Set up the plan instance and its per-framework roles.

    Args:
        name: display name; the class name is used when this is falsy.
        include_state: stored on ``self.include_state``.
        is_built: stored on ``self.is_built``.
        forward_func: becomes ``self.forward`` unless the class already has ``forward``.
        state_tensors: tensors registered as state on the newly created Role
            when ``role`` is None; ignored otherwise. Defaults to none.
        role: Role to reuse; a new ``Role()`` is created if falsy.
        id, owner, tags, description: forwarded to the parent ``__init__``.
        input_types: stored on ``self.input_types``.
        roles: mapping of framework name to Role; when falsy it defaults to
            ``{base_framework: self.role}``.
        base_framework: stored on ``self._base_framework``.
    """
    super().__init__(id, owner, tags, description, child=None)

    # None instead of a shared mutable `[]` default.
    if state_tensors is None:
        state_tensors = []

    # Plan instance info
    self.name = name or self.__class__.__name__
    self.role = role or Role()
    if role is None:
        for st in state_tensors:
            self.role.register_state_tensor(st)

    self.include_state = include_state
    self.is_building = False
    self.state_attributes = {}
    self.is_built = is_built
    self.torchscript = None
    self.input_types = input_types
    self.validate_input_types = True
    self.tracing = False
    self._base_framework = base_framework
    self.roles = roles or {base_framework: self.role}

    # The plan has not been sent so it has no reference to remote locations
    self.pointers = {}

    if not hasattr(self, "forward"):
        self.forward = forward_func or None

    self.__name__ = self.__repr__()  # For PyTorch jit tracing compatibility

    # List of available translations
    self.translations = []
def test_reset():
    """Role.reset() clears actions, placeholders and input/output placeholder ids."""

    def snapshot(r):
        # Compare the four pieces of state that reset() is expected to clear.
        return (
            len(r.actions),
            len(r.placeholders),
            r.input_placeholder_ids,
            r.output_placeholder_ids,
        )

    role = Role()
    output_ph = PlaceHolder()
    operand = torch.ones([1])
    get_action = ("get", operand, (), {})
    role.register_action((get_action, output_ph), CommunicationAction)

    role.placeholders = {"ph_id1": PlaceHolder(), "ph_id2": PlaceHolder()}
    role.input_placeholder_ids = ("input1", "input2")
    role.output_placeholder_ids = ("output1",)

    assert snapshot(role) == (1, 2, ("input1", "input2"), ("output1",))

    role.reset()

    assert snapshot(role) == (0, 0, (), ())
def get_role_for_owner(self, owner):
    """Return the Role stored under ``owner.id``, creating an empty one the first time."""
    try:
        return self.roles[owner.id]
    except KeyError:
        new_role = self.roles[owner.id] = Role()
        return new_role