Beispiel #1
0
 def register(self, result):
     """Register *result* (and the SyftTensor chain it wraps) with this worker.

     Dispatches on the runtime type of ``result``: SyftTensors are
     registered directly, torch tensors/variables through their ``.child``
     chain (including grad chains for variables), iterable containers
     element-wise, and numpy arrays directly. ``None`` is a no-op.

     Raises:
         TypeError: when the type of ``result`` is not supported.
     """
     if issubclass(type(result), sy._SyftTensor):
         self.register_object(result)
     elif torch_utils.is_tensor(result):
         self.register_object(result.child)
     elif torch_utils.is_variable(result):
         var = result
         self.register(var.child)
         self.register(var.data.child)
         # Make sure a grad chain exists before trying to register it.
         if getattr(var, "grad", None) is None:
             var.init_grad_()
         self.register(var.grad.child)
         self.register(var.grad.data.child)
     elif isinstance(result, (list, tuple, set, bytearray, range)):
         # Iterable container that json cannot serialize: register each
         # element recursively.
         for element in result:
             self.register(element)
     elif result is None:
         pass  # nothing to register
     elif isinstance(result, np.ndarray):
         self.register_object(result)
     else:
         raise TypeError("The type", type(result), "is not supported at the moment")
     return
Beispiel #2
0
 def python_encode(self, obj, private_local):
     """Recursively encode *obj* into a JSON-serializable structure.

     Args:
         obj: the object to encode (scalar, dict, list, iterable, slice,
             generator, worker, numpy array, or a torch/syft tensor chain).
         private_local: forwarded as ``private`` to each object's ``ser``.

     Returns:
         A JSON-serializable equivalent of *obj*.

     Raises:
         ValueError: if the type of *obj* is not handled.

     Side effect: when ``self.retrieve_pointers`` is set, every
     ``_PointerTensor`` found at the tail of a tensor chain is appended to
     ``self.found_pointers``.
     """
     # Case of basic types
     if isinstance(obj, (int, float, str)) or obj is None:
         return obj
     elif isinstance(obj, type(...)):
         return "..."
     elif isinstance(obj, np.ndarray):
         return obj.ser(private=private_local, to_json=False)
     # Dict
     elif isinstance(obj, dict):
         return {
             k: self.python_encode(v, private_local)
             for k, v in obj.items()
         }
     # Variable
     elif torch_utils.is_variable(obj):
         tail_object = torch_utils.find_tail_of_chain(obj)
         if self.retrieve_pointers and isinstance(tail_object,
                                                  sy._PointerTensor):
             self.found_pointers.append(tail_object)
         return obj.ser(private=private_local, is_head=True)
     # Tensors
     elif torch_utils.is_tensor(obj):
         tail_object = torch_utils.find_tail_of_chain(obj)
         if self.retrieve_pointers and isinstance(tail_object,
                                                  sy._PointerTensor):
             self.found_pointers.append(tail_object)
         return obj.ser(private=private_local)
     # sy._SyftTensor (Pointer, Local)
     # [Note: shouldn't be called on regular chain with end=tensorvar]
     elif torch_utils.is_syft_tensor(obj):
         tail_object = torch_utils.find_tail_of_chain(obj)
         if self.retrieve_pointers and isinstance(tail_object,
                                                  sy._PointerTensor):
             self.found_pointers.append(tail_object)
         return obj.ser(private=private_local)
     # List
     elif isinstance(obj, list):
         return [self.python_encode(i, private_local) for i in obj]
     # NOTE: a second, duplicate ``np.ndarray`` branch used to live here;
     # it was unreachable (arrays are already handled near the top of the
     # chain) and has been removed.
     # Iterables non json-serializable
     elif isinstance(obj, (tuple, set, bytearray, range)):
         key = get_serialized_key(obj)
         return {key: [self.python_encode(i, private_local) for i in obj]}
     # Slice
     elif isinstance(obj, slice):
         key = get_serialized_key(obj)
         return {key: {'args': [obj.start, obj.stop, obj.step]}}
     # Generator
     elif isinstance(obj, types.GeneratorType):
         logging.warning("Generator args can't be transmitted")
         return []
     # worker
     elif isinstance(obj, (sy.SocketWorker, sy.VirtualWorker)):
         return {'__worker__': obj.id}
     # Else log the error
     else:
         raise ValueError('Unhandled type', type(obj))
Beispiel #3
0
    def de_register_object(self, obj, _recurse_torch_objs=True):
        """Unregisters an object and removes attributes which are indicative of
        registration.

        Note that the way in which attributes are deleted has been
        informed by this StackOverflow post: https://goo.gl/CBEKLK

        Args:
            obj: the object (torch tensor or syft tensor node) to
                un-register.
            _recurse_torch_objs: internal flag; when False, recursion
                stops at the first torch tensor met in the chain.
        """

        is_torch_tensor = torch_utils.is_tensor(obj)

        # Only non-torch (syft) nodes carry registration attributes that
        # must be stripped; torch tensors keep ``id``/``owner`` intact.
        if not is_torch_tensor:
            if hasattr(obj, "id"):
                self.rm_obj(obj.id)
                del obj.id
            if hasattr(obj, "owner"):
                del obj.owner

        if hasattr(obj, "child"):
            if obj.child is not None:
                if is_torch_tensor:
                    # Recurse one level into a torch tensor's chain, but
                    # disable further torch-level recursion below it.
                    if _recurse_torch_objs:
                        self.de_register_object(obj.child, _recurse_torch_objs=False)
                else:
                    self.de_register_object(
                        obj.child, _recurse_torch_objs=_recurse_torch_objs
                    )
            # The ``child`` link itself is only severed on non-torch nodes.
            if not is_torch_tensor:
                delattr(obj, "child")
Beispiel #4
0
 def python_encode(self, obj, private_local):
     """Recursively encode *obj* into a JSON-serializable structure.

     Branches are ordered by expected call frequency. When
     ``self.retrieve_pointers`` is set, every ``_PointerTensor`` found at
     the tail of a tensor chain is recorded in ``self.found_pointers``.

     Raises:
         ValueError: if the type of *obj* is not handled.
     """

     def _track_pointer_tail(chain_head):
         # Remember the tail of the chain when it is a pointer and
         # pointer retrieval was requested.
         tail = torch_utils.find_tail_of_chain(chain_head)
         if self.retrieve_pointers and isinstance(tail, sy._PointerTensor):
             self.found_pointers.append(tail)

     # Dict: encode values, keep keys as-is.
     if isinstance(obj, dict):
         return {key: self.python_encode(val, private_local)
                 for key, val in obj.items()}
     # sy._SyftTensor (Pointer, Local)
     if issubclass(type(obj), sy._SyftTensor):
         _track_pointer_tail(obj)
         return obj.ser(private=private_local)
     # JSON-native scalars (and None) pass through untouched.
     if isinstance(obj, (int, float, str)) or obj is None:
         return obj
     # List: element-wise encoding.
     if isinstance(obj, list):
         return [self.python_encode(item, private_local) for item in obj]
     # Iterables json cannot serialize: tag the payload with a type key.
     if isinstance(obj, (tuple, set, bytearray, range)):
         key = get_serialized_key(obj)
         return {key: [self.python_encode(item, private_local) for item in obj]}
     # Variable: serialize as the head of its chain.
     if torch_utils.is_variable(obj):
         _track_pointer_tail(obj)
         return obj.ser(private=private_local, is_head=True)
     # Torch tensor.
     if torch_utils.is_tensor(obj):
         _track_pointer_tail(obj)
         return obj.ser(private=private_local)
     # The Ellipsis literal travels as a string.
     if isinstance(obj, type(...)):
         return "..."
     # numpy array.
     if isinstance(obj, np.ndarray):
         return obj.ser(private=private_local, to_json=False)
     # Slice: keep only its constructor arguments.
     if isinstance(obj, slice):
         key = get_serialized_key(obj)
         return {key: {"args": [obj.start, obj.stop, obj.step]}}
     # Generators cannot be transmitted.
     if isinstance(obj, types.GeneratorType):
         logging.warning("Generator args can't be transmitted")
         return []
     # Workers travel as their id.
     if isinstance(obj, (sy.SocketWorker, sy.VirtualWorker)):
         return {"__worker__": obj.id}
     raise ValueError("Unhandled type", type(obj))
Beispiel #5
0
        def _execute_method_call(self, *args, **kwargs):
            """Run the hooked method call on the local worker; when it is
            not implemented at this level of the chain, recurse on the
            child node and re-wrap the result in the caller's type."""
            worker = hook_self.local_worker
            try:
                return worker._execute_call(attr, self, *args, **kwargs)
            except NotImplementedError:
                # Delegate one level down the tensor chain.
                child_result = _execute_method_call(self.child, *args, **kwargs)
                if torch_utils.is_tensor(self):
                    return child_result
                # Re-wrap the child's result in the caller's own type.
                wrapped = type(self)(child_result)
                if hasattr(wrapped, "second_constructor"):
                    wrapped = wrapped.second_constructor()
                return wrapped
Beispiel #6
0
    def de_register(self, obj):
        """Un-register *obj* (and its chain or elements) from this worker.

        Mirrors ``register``: syft tensors are removed by id, torch
        tensors/variables through their ``.child`` chain, iterable
        containers element-wise. ``None`` is a no-op.

        Raises:
            TypeError: when the type of *obj* is not supported.
        """
        if obj is None:
            return  # nothing to de-register
        if issubclass(type(obj), sy._SyftTensor):
            self.rm_obj(obj.id)
            # TODO: rm .data, .grad, .grad.data if any
        elif torch_utils.is_tensor(obj):
            self.de_register(obj.child)
        elif torch_utils.is_variable(obj):
            self.de_register(obj.child)
            self.de_register(obj.data.child)
            # TODO: rm .grad, .grad.data if any
        elif isinstance(obj, (list, tuple, set, bytearray, range)):
            # Iterable container: de-register each element recursively.
            for element in obj:
                self.de_register(element)
        elif isinstance(obj, np.ndarray):
            self.rm_obj(obj.id)
        else:
            raise TypeError("The type", type(obj), "is not supported at the moment")
        return
Beispiel #7
0
    def share(self, *workers):
        """Secret-share every tensor held by the game state across *workers*.

        Tensors are cast to long dtype before sharing; entries whose chain
        is already an ``_SNNTensor`` are left untouched.
        """
        # Plot tensors are replaced by their shared counterparts.
        for key, plot_tensor in self._the_plot.items():
            self._the_plot[key] = plot_tensor.long().share(*workers)

        self._board.board.share(*workers)

        self._board.layered_board.share(*workers)

        for key, layer in self._board.layers.items():
            if not isinstance(layer.child, sy._SNNTensor):
                self._board.layers[key] = layer.long().share(*workers)

        # Walk the 3-level nested update groups and share every tensor
        # attribute that is not already an _SNNTensor.
        for group in self._update_groups:
            for row in group:
                for member in row:
                    if isinstance(member, str):
                        continue
                    for attr_name, attr_value in member.__dict__.items():
                        if not utils.is_tensor(attr_value):
                            continue
                        if isinstance(attr_value.child, sy._SNNTensor):
                            continue
                        member.__dict__[attr_name] = attr_value.long().share(
                            *workers)

        self._backdrop.curtain.share(*workers)
Beispiel #8
0
    def send(self, location):
        """Send every tensor held by the game state to *location*.

        Plot tensors are replaced by the result of ``send``; for the other
        tensors the return value of ``send`` is discarded — NOTE(review):
        presumably ``send`` mutates the tensor chain in place here, unlike
        ``share``; confirm against the tensor API.
        """
        for key, plot_tensor in self._the_plot.items():
            self._the_plot[key] = plot_tensor.send(location)

        self._board.board.send(location)

        self._board.layered_board.send(location)

        for layer in self._board.layers.values():
            if not isinstance(layer.child, sy._PointerTensor):
                layer.send(location)

        # Walk the 3-level nested update groups and send every tensor
        # attribute that is not already a pointer.
        for group in self._update_groups:
            for row in group:
                for member in row:
                    if isinstance(member, str):
                        continue
                    for attr_value in member.__dict__.values():
                        if not utils.is_tensor(attr_value):
                            continue
                        if not isinstance(attr_value.child, sy._PointerTensor):
                            attr_value.send(location)

        self._backdrop.curtain.send(location)
Beispiel #9
0
    def _execute_call(self, attr, self_, *args, **kwargs):
        """Transmit the call to the appropriate TensorType for handling.

        Args:
            attr: name of the method/function being executed.
            self_: the receiver of a method call, or None for a torch
                module-level function.
            *args, **kwargs: the call's arguments.

        Returns:
            The (re-wrapped, owner-enforced) result of the call.
        """

        # if this is none - then it means that self_ is not a torch wrapper
        # and we need to execute one level higher TODO: not ok for complex args
        if self_ is not None and self_.child is None:
            new_args = [
                arg.wrap(True) if not isinstance(arg, int) else arg for arg in args
            ]
            return self._execute_call(attr, self_.wrap(True), *new_args, **kwargs)

        # Distinguish between a command with torch tensors (like when called by the client,
        # or received from another worker), and a command with syft tensor, which can occur
        # when a function is overloaded by a SyftTensor (for instance _PlusIsMinusTensor
        # overloads add and replace it by sub)
        try:
            torch_utils.assert_has_only_torch_tensorvars((args, kwargs))
            is_torch_command = True
        except AssertionError:
            is_torch_command = False

        has_self = self_ is not None

        if has_self:
            command = torch._command_guard(attr, "tensorvar_methods")
        else:
            command = torch._command_guard(attr, "torch_modules")

        raw_command = {
            "command": command,
            "has_self": has_self,
            "args": args,
            "kwargs": kwargs,
        }
        if has_self:
            raw_command["self"] = self_

        # Both the torch and the syft paths unwrap the command identically
        # (the two branches previously here were byte-for-byte duplicates):
        # for a torch command this strips the torch wrapper; for a syft
        # command it resolves the next syft class in the chain — the one
        # which redirected the call (see the _PlusIsMinus example).
        syft_command, child_type = torch_utils.prepare_child_command(
            raw_command, replace_tensorvar_with_child=True
        )

        # Note: because we have pb of registration of tensors with the right worker,
        # and because having Virtual workers creates even more ambiguity, we specify the worker
        # performing the operation

        result = child_type.handle_call(syft_command, owner=self)

        if is_torch_command:
            # Wrap the result
            if has_self and utils.is_in_place_method(attr):
                # TODO: fix this properly: don't wrap the same way if syft or Variable
                if torch_utils.is_variable(result) or torch_utils.is_tensor(result):
                    wrapper = torch_utils.bind_tensor_nodes(
                        raw_command["self"], result.child
                    )
                else:
                    wrapper = torch_utils.bind_tensor_nodes(raw_command["self"], result)
            else:
                wrapper = torch_utils.wrap_command(result)
            torch_utils.enforce_owner(wrapper, self)
            return wrapper
        else:
            # We don't need to wrap
            torch_utils.enforce_owner(result, self)
            return result
Beispiel #10
0
    def python_decode(self, dct):
        """Reconstruct Python/torch/syft objects from their decoded-JSON form.

        Called on every dict found during deserialization. Keys matching
        ``__<type>__`` signal a special type that must be re-cast (tuple,
        torch Variable, syft tensor, ...); anything else is returned with
        its values recursively decoded.

        Raises:
            TypeError: for unhandled input types or unsupported special
                object types.
        """

        # PLAN A: See if the dct object is not actually a dictionary and
        # address each case.
        if isinstance(dct, (int, str, float)):
            # A very strange special case: the Ellipsis literal travels as
            # the string "...".
            if dct == '...':
                return ...
            return dct
        if isinstance(dct, list):
            return [self.python_decode(o) for o in dct]
        if dct is None:
            return None
        if not isinstance(dct, dict):
            print(type(dct))
            raise TypeError('Type not handled', dct)

        # PLAN B: the dct object IS a dictionary — check for a "type" key.
        if 'type' in dct:
            if dct['type'] == "numpy.array":
                # The same dict encodes both an array and a pointer to one:
                # ``self.acquire`` decides which is built on this (receiving)
                # side. Note that this changes how dct['id'] is used: for an
                # actual tensor the tensor id is set to dct['id']; otherwise
                # id_at_location is set to dct['id']. Similarly with
                # dct['owner'].
                if self.acquire:
                    # We intend to receive the tensor itself.
                    return array(dct['data'], id=dct['id'], owner=self.worker)
                # Otherwise construct a pointer to the remote array.
                return array_ptr(dct['data'],
                                 owner=self.worker,
                                 location=self.worker.get_worker(
                                     dct['owner']),
                                 id_at_location=dct['id'])
            elif dct['type'] == 'numpy.array_ptr':
                return self.worker.get_obj(dct['id_at_location'])

        # PLAN C: as a last resort, use a regex to find a ``__type__`` key.
        # TODO: Plan C should never be called — but is used extensively in
        # PySyft's PyTorch integration.
        pat = re.compile('__(.+)__')
        for key, obj in dct.items():
            match = pat.search(key)
            if match is None:
                # Ordinary key: decode its value in place.
                dct[key] = self.python_decode(obj)
                continue
            obj_type = match.group(1)
            # Case of a tensor
            if torch_utils.is_tensor(obj_type):
                return torch.guard['syft.' + obj_type].deser({key: obj},
                                                             self.worker,
                                                             self.acquire)
            # Case of a Variable
            if torch_utils.is_variable(obj_type):
                return sy.Variable.deser({key: obj},
                                         self.worker,
                                         self.acquire,
                                         is_head=True)
            # Case of a Syft tensor
            if torch_utils.is_syft_tensor(obj_type):
                return sy._SyftTensor.deser_routing({key: obj},
                                                    self.worker,
                                                    self.acquire)
            # Iterable types json cannot serialize directly. An explicit
            # dispatch table replaces the previous ``eval(obj_type)``:
            # identical behavior for the four allowed names, without the
            # arbitrary-code-execution hazard.
            if obj_type in ('tuple', 'set', 'bytearray', 'range'):
                caster = {'tuple': tuple, 'set': set,
                          'bytearray': bytearray, 'range': range}[obj_type]
                return caster([self.python_decode(o) for o in obj])
            # Case of a slice
            if obj_type == 'slice':
                return slice(*obj['args'])
            # Case of a worker
            if obj_type == 'worker':
                return self.worker.get_worker(obj)
            raise TypeError('The special object type', obj_type,
                            'is not supported')
        return dct