Example #1
        def new_backward(self, *args, **kwargs):
            worker = self.owner
            # Retrieve all the variable ids involved in the computation graph
            variable_ids = torch_utils.get_connected_variables(self)
            variable_ids = [
                var_id for var_id in variable_ids if var_id in worker._objects
            ]
            # Save all the gradients (to keep the id) and reset the grads
            saved_grads = {}
            for variable_id in variable_ids:
                syft_tensor = worker.get_obj(variable_id)
                var = syft_tensor.parent
                assert var.id == variable_id
                saved_grads[variable_id] = var.grad
                var.grad = None

            # Perform the backward pass using the original (pre-hook) backward,
            # which the hook exposes as native_native_backward
            self.native_native_backward(*args, **kwargs)

            # Put back the original grad envelope and insert the new grad value into it
            for variable_id in variable_ids:
                syft_tensor = worker.get_obj(variable_id)
                # retrieve the var to fix
                var = syft_tensor.parent
                # retrieve the old grad, and insert it (to keep the chain) [first the envelope, then the data]
                saved_grad = saved_grads[variable_id]
                if saved_grad is not None:
                    # store the computed gradient
                    computed_grad = var.grad
                    var.assign_grad_(saved_grad)
                    # Insert the value of the computed_grad
                    if computed_grad is not None:
                        var.grad.data.native_set_(computed_grad.data)
                # Make sure everyone has the right owner
                torch_utils.enforce_owner(var, worker)
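
The snippet above assumes Variable.backward has already been monkey-patched by the hook, so the original implementation stays reachable under a native_-prefixed name (here seen twice as native_native_backward). Below is a minimal, self-contained sketch of that wrapping pattern; install_backward_wrapper and logging_backward are illustrative names assumed for this example, not PySyft's hook code, and a single native_ prefix is used.

    import torch

    def install_backward_wrapper(cls, wrapper):
        # Keep the original implementation reachable under a prefixed name,
        # exactly once, so the wrapper can delegate to it
        if not hasattr(cls, "native_backward"):
            cls.native_backward = cls.backward
        cls.backward = wrapper

    def logging_backward(self, *args, **kwargs):
        # Do some bookkeeping, then delegate to the original backward
        print("backward called on a tensor of shape", tuple(self.shape))
        return self.native_backward(*args, **kwargs)

    install_backward_wrapper(torch.Tensor, logging_backward)

    x = torch.ones(3, requires_grad=True)
    (x * 2).sum().backward()
    print(x.grad)  # tensor([2., 2., 2.])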
Example #2
    def _execute_numpy_call(self, attr, self_, *args, **kwargs):
        """Transmit the call to the appropriate TensorType for handling."""

        # Distinguish between a command with torch tensors (e.g. when called by the client
        # or received from another worker) and a command with syft tensors, which can occur
        # when a function is overloaded by a SyftTensor (for instance _PlusIsMinusTensor
        # overloads add and replaces it with sub)
        # try:
        #     torch_utils.assert_has_only_torch_tensorvars((args, kwargs))
        #     is_torch_command = True
        # except AssertionError:
        is_torch_command = False

        has_self = self_ is not None

        # if has_self:
        #     command = torch._command_guard(attr, 'tensorvar_methods')
        # else:
        #     command = torch._command_guard(attr, 'torch_modules')
        command = attr

        raw_command = {
            "command": command,
            "has_self": has_self,
            "args": args,
            "kwargs": kwargs,
        }
        if has_self:
            raw_command["self"] = self_

        # if is_torch_command:
        #     # Unwrap the torch wrapper
        #     syft_command, child_type = torch_utils.prepare_child_command(
        #         raw_command, replace_tensorvar_with_child=True)
        # else:
        #     # Get the next syft class
        #     # The actual syft class is the one which redirected (see the  _PlusIsMinus ex.)
        #     syft_command, child_type = torch_utils.prepare_child_command(
        #         raw_command, replace_tensorvar_with_child=True)
        #
        #     torch_utils.assert_has_only_syft_tensors(syft_command)

        # Note: because we have problems registering tensors with the right worker,
        # and because Virtual workers create even more ambiguity, we explicitly specify
        # the worker performing the operation
        # torch_utils.enforce_owner(raw_command, self)

        result = sy.array.handle_call(raw_command, owner=self)

        torch_utils.enforce_owner(result, self)

        if is_torch_command:
            # Wrap the result
            if has_self and utils.is_in_place_method(attr):
                result = torch_utils.bind_tensor_nodes(raw_command["self"],
                                                       result)
            else:
                result = torch_utils.wrap_command(result)

        return result
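
The call above packs the invocation into a plain command dict and hands it to sy.array.handle_call. A hedged sketch of what such a dispatcher can look like, with a hypothetical handle_call that resolves the command either as a method on "self" or as a module-level numpy function (illustrative only, not PySyft's implementation):

    import numpy as np

    def handle_call(raw_command, owner=None):
        command = raw_command["command"]
        args, kwargs = raw_command["args"], raw_command["kwargs"]
        if raw_command["has_self"]:
            # Method call: resolve the attribute on the receiver
            return getattr(raw_command["self"], command)(*args, **kwargs)
        # Free function: resolve against the numpy module
        return getattr(np, command)(*args, **kwargs)

    raw_command = {
        "command": "sum",
        "has_self": True,
        "args": (),
        "kwargs": {"axis": 0},
        "self": np.arange(6).reshape(2, 3),
    }
    print(handle_call(raw_command))  # [3 5 7]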
Example #3
    def _execute_call(self, attr, self_, *args, **kwargs):
        """Transmit the call to the appropriate TensorType for handling."""

        # If this is None, then self_ is not a torch wrapper and we need to execute
        # one level higher. TODO: not ok for complex args
        if self_ is not None and self_.child is None:
            new_args = [
                arg.wrap(True) if not isinstance(arg, int) else arg for arg in args
            ]
            return self._execute_call(attr, self_.wrap(True), *new_args, **kwargs)

        # Distinguish between a command with torch tensors (e.g. when called by the client
        # or received from another worker) and a command with syft tensors, which can occur
        # when a function is overloaded by a SyftTensor (for instance _PlusIsMinusTensor
        # overloads add and replaces it with sub)
        try:
            torch_utils.assert_has_only_torch_tensorvars((args, kwargs))
            is_torch_command = True
        except AssertionError:
            is_torch_command = False

        has_self = self_ is not None

        if has_self:
            command = torch._command_guard(attr, "tensorvar_methods")
        else:
            command = torch._command_guard(attr, "torch_modules")

        raw_command = {
            "command": command,
            "has_self": has_self,
            "args": args,
            "kwargs": kwargs,
        }
        if has_self:
            raw_command["self"] = self_
        if is_torch_command:
            # Unwrap the torch wrapper
            syft_command, child_type = torch_utils.prepare_child_command(
                raw_command, replace_tensorvar_with_child=True
            )
        else:
            # Get the next syft class
            # The actual syft class is the one which redirected the call (see the _PlusIsMinusTensor example)
            syft_command, child_type = torch_utils.prepare_child_command(
                raw_command, replace_tensorvar_with_child=True
            )

            # torch_utils.assert_has_only_syft_tensors(syft_command)

        # Note: because we have problems registering tensors with the right worker,
        # and because Virtual workers create even more ambiguity, we explicitly specify
        # the worker performing the operation

        result = child_type.handle_call(syft_command, owner=self)

        if is_torch_command:
            # Wrap the result
            if has_self and utils.is_in_place_method(attr):
                # TODO: fix this properly: don't wrap the same way for syft tensors as for Variables
                if torch_utils.is_variable(result) or torch_utils.is_tensor(result):
                    wrapper = torch_utils.bind_tensor_nodes(
                        raw_command["self"], result.child
                    )
                else:
                    wrapper = torch_utils.bind_tensor_nodes(raw_command["self"], result)
            else:
                wrapper = torch_utils.wrap_command(result)
            torch_utils.enforce_owner(wrapper, self)
            return wrapper
        else:
            # We don't need to wrap
            torch_utils.enforce_owner(result, self)
            return result
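
The torch-vs-syft test at the top of this method relies on torch_utils.assert_has_only_torch_tensorvars raising AssertionError when any argument is not a plain torch tensor or Variable. An illustrative sketch of that kind of check under assumed helper names (not PySyft's torch_utils):

    import torch

    def assert_has_only_torch_tensors(obj):
        # Walk nested args/kwargs and assert every non-primitive entry is a torch tensor
        if isinstance(obj, (list, tuple)):
            for item in obj:
                assert_has_only_torch_tensors(item)
        elif isinstance(obj, dict):
            for item in obj.values():
                assert_has_only_torch_tensors(item)
        elif not isinstance(obj, (int, float, bool, str, type(None))):
            assert isinstance(obj, torch.Tensor), f"{type(obj)} is not a torch tensor"

    def looks_like_torch_command(args, kwargs):
        try:
            assert_has_only_torch_tensors((args, kwargs))
            return True
        except AssertionError:
            return False

    print(looks_like_torch_command((torch.ones(2), 3), {}))  # True
    print(looks_like_torch_command((object(),), {}))         # False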