コード例 #1
0
ファイル: plan.py プロジェクト: theAkhileshRai/PySyft
class Plan(AbstractObject):
    """
    A Plan stores a sequence of actions, just like a function.

    A Plan is intended to store a sequence of actions, just like a function,
    but it allows to send this sequence of actions to remote workers and to keep a
    reference to it. This way, to compute remotely this sequence of actions on some remote
    input referenced through pointers, instead of sending multiple messages you need now to send a
    single message with the references of the plan and the pointers.

    All arguments are optional.

    Args:
        name: the name of the name
        state: store the plan tensors like model parameters
        include_state: if true, implies that the plan is a function, else a class. If true, the
            state is re-integrated in the args to be accessed within the function
        is_built: state if the plan has already been built.
        placeholders: dict of placeholders used in the plan
        actions: list of commands (called actions)
        forward_func: the function to be transformed into a plan
        state_tensors: a tuple of state elements. It can be used to populate a state
        id: plan id
        owner: plan owner
        tags: plan tags
        description: plan description
    """

    _build_translators = []
    _wrapped_frameworks = {}

    def __init__(
        self,
        name: str = None,
        include_state: bool = False,
        is_built: bool = False,
        forward_func=None,
        state_tensors=[],
        role: Role = None,
        # General kwargs
        id: Union[str, int] = None,
        owner: "sy.workers.BaseWorker" = None,
        tags: List[str] = None,
        input_types: list = None,
        description: str = None,
    ):
        AbstractObject.__init__(self, id, owner, tags, description, child=None)

        # Plan instance info
        self.name = name or self.__class__.__name__

        self.role = role or Role()

        if role is None:
            for st in state_tensors:
                self.role.register_state_tensor(st)

        self.include_state = include_state
        self.is_building = False
        self.state_attributes = {}
        self.is_built = is_built
        self.torchscript = None
        self.input_types = input_types
        self.tracing = False

        # The plan has not been sent so it has no reference to remote locations
        self.pointers = dict()

        if not hasattr(self, "forward"):
            self.forward = forward_func or None

        self.__name__ = self.__repr__(
        )  # For PyTorch jit tracing compatibility

        # List of available translations
        self.translations = []

    @property
    def state(self):
        return self.role.state

    @property
    def actions(self):
        return self.role.actions

    def parameters(self):
        """
        This is defined to match the torch api of nn.Module where .parameters() return the model tensors / parameters
        """
        if self.state is not None:
            return self.state.tensors()
        else:
            return []

    def build(self, *args, trace_autograd=False):
        """Builds the plan.

        First, run the function to be converted in a plan in a context which
        activates the tracing and record the actions in trace.logs

        Second, store the result ids temporarily to helper ordering the output
        placeholders at return time

        Third, loop through the trace logs and replace the tensors found in the
        actions logged by PlaceHolders. Record those actions in
        plan.actions

        Args:
            args: Input arguments to run the plan
        """
        # Reset previous build
        self.role.reset()

        def build_nested_arg(arg, leaf_function):
            if isinstance(arg, list):
                return [build_nested_arg(obj, leaf_function) for obj in arg]
            elif isinstance(arg, tuple):
                return tuple(
                    [build_nested_arg(obj, leaf_function) for obj in arg])
            elif isinstance(arg, dict):
                return {
                    k: build_nested_arg(v, leaf_function)
                    for k, v in arg.items()
                }
            else:
                return leaf_function(arg)

        # Enable tracing
        self.toggle_tracing(True)
        self.is_building = True

        # typecheck
        self.input_types = NestedTypeWrapper(args)

        # Run once to build the plan
        if trace_autograd:
            # Wrap arguments that require gradients with AutogradTensor,
            # to be able to trace autograd operations
            args = build_nested_arg(
                args,
                lambda x: AutogradTensor().on(x, wrap=False)
                if isinstance(x, FrameworkTensor) and x.requires_grad else x,
            )
            # Add Placeholder after AutogradTensor in the chain
            # so that all operations that happen inside AutogradTensor are recorded by Placeholder
            args_placeholders = build_nested_arg(
                args,
                lambda x: PlaceHolder.insert(
                    x, AutogradTensor, role=self.role, tracing=True),
            )
        else:
            # Add Placeholder on top of each arg
            args = args_placeholders = build_nested_arg(
                args,
                lambda x: PlaceHolder.create_from(
                    x, role=self.role, tracing=True),
            )

        # Add state to args if needed
        if self.include_state:
            args += (self.state, )

        # Check the plan arguments to see what framework wrappers we might need to send to the plan
        framework_kwargs = {}

        forward_args = inspect.getfullargspec(self.forward).args
        for f_name, wrap_framework_func in Plan._wrapped_frameworks.items():
            if f_name in forward_args:
                framework_kwargs[f_name] = wrap_framework_func(
                    self.role, self.owner)

        results = self.forward(*args, **framework_kwargs)

        # Disable tracing
        self.toggle_tracing(False)
        self.is_building = False

        # Register inputs in role
        self.role.register_inputs(args_placeholders)

        # Register outputs in role
        if isinstance(results, (tuple, list)):
            results_placeholders = tuple(
                PlaceHolder.extract(result) for result in results)
        else:
            results_placeholders = PlaceHolder.extract(results)
        self.role.register_outputs(results_placeholders)

        self.is_built = True

        # Build registered translations
        for translator in Plan._build_translators:
            try:
                self.add_translation(translator)
                self.translations.append(translator)
            except:
                warnings.warn(
                    f"Failed to translate Plan with {translator.__name__}")

        return results

    def toggle_tracing(self, value=None):
        self.tracing = value if value is not None else not self.tracing
        self.state.tracing = self.tracing
        for ph in self.role.placeholders.values():
            ph.tracing = self.tracing

    def copy(self):
        """Creates a copy of a plan."""
        plan_copy = Plan(
            name=self.name,
            role=self.role.copy(),
            include_state=self.include_state,
            is_built=self.is_built,
            id=sy.ID_PROVIDER.pop(),
            owner=self.owner,
            tags=self.tags,
            input_types=self.input_types,
            description=self.description,
        )

        plan_copy.torchscript = self.torchscript

        return plan_copy

    def __setattr__(self, name, value):
        """Add new tensors or parameter attributes to the state and register them
        in the owner's registry
        """
        if isinstance(value, torch.jit.ScriptModule):
            object.__setattr__(self, name, value)
        elif isinstance(value, FrameworkTensor):
            self.role.register_state_tensor(value)
            self.state_attributes[name] = value
        elif isinstance(value, FrameworkLayerModule):
            for param in value.parameters():
                self.role.register_state_tensor(param)
            self.state_attributes[name] = value
        else:
            object.__setattr__(self, name, value)

    def __getattr__(self, name):
        if name not in self.state_attributes:
            raise AttributeError("State attribute not found.")

        value = self.state_attributes[name]
        if not self.is_building:
            return value

        if isinstance(value, FrameworkTensor):
            return self.role.placeholders[value.id]
        elif isinstance(value, FrameworkLayerModule):
            # We need to deepcopy here otherwise the real layer is modified when the Plan is being built
            copied_layer = copy.deepcopy(value)
            for copied_param, param in zip(copied_layer.named_parameters(),
                                           value.parameters()):
                (copied_name, _) = copied_param
                copied_layer._parameters[copied_name] = self.role.placeholders[
                    param.id]

            return copied_layer

    def __call__(self, *args):
        """
        Calls a plan execution with some arguments.

        When possible, run the original function to improve efficiency. When
        it's not, for example if you fetched the plan from a remote worker,
        then run it from the tape of actions:
        - Instantiate input placeholders
        - for each recorded action, run the action on the placeholders
          and use the result(s) to instantiate to appropriate placeholder.
        - Return the instantiation of all the output placeholders.
        """
        if self.forward is not None:
            if self.include_state:
                args = (*args, self.state)
            return self.forward(*args)
        else:
            self.input_types.input_check(self, args)
            self.role.instantiate_inputs(args)
            result = self.role.execute()
            if len(result) == 1:
                return result[0]
            return result

    def run(self, args_: Tuple, result_ids: List[Union[str, int]]):
        """Controls local or remote plan execution.
        If the plan doesn't have the plan built, first build it using the original function.

        Args:
            args_: Arguments used to run plan.
            result_ids: List of ids where the results will be stored.
        """
        # TODO: can we reuse result_ids?
        return self.__call__(*args_)

    def send(self, *locations: AbstractWorker) -> PointerPlan:
        """Send plan to locations.

        If the plan was not built locally it will raise an exception.
        If `force` = true plan is going to be sent either way.

        Args:
            locations: List of workers.
            force: A boolean indicating if this action should be forced.
        """
        if not self.is_built:
            raise RuntimeError(
                "A plan needs to be built before being sent to a worker.")

        if len(locations) == 1:
            location = locations[0]

            # Check if plan was already sent at the location
            if location in self.pointers:
                return self.pointers[location]

            # Send the Plan
            pointer = self.owner.send(self, workers=location)

            self.pointers[location] = pointer
        else:
            ids_at_location = []
            for location in locations:
                if location in self.pointers:
                    # Use the pointer that was already sent
                    pointer = self.pointers[location]
                else:
                    # Send the Plan
                    pointer = self.owner.send(self, workers=location)

                    self.pointers[location] = pointer

                ids_at_location.append(pointer.id_at_location)

            pointer = sy.PointerPlan(location=locations,
                                     id_at_location=ids_at_location)

        return pointer

    def get_args_shape(self):
        """Returns input tensors shapes"""
        if not self.is_built:
            raise RuntimeError(
                "A plan needs to be built before input shapes can be known.")

        return [ph.expected_shape for ph in self.role.input_placeholders()]

    @staticmethod
    def register_build_translator(translator: "AbstractPlanTranslator"):
        Plan._build_translators.append(translator)

    @staticmethod
    def register_framework(f_name, f_package):
        """
        When we use methods defined in a framework (like: torch.randn) we have a framework
        wrapper that helps as register and keep track of what methods are called
        With the below lines, we "register" what frameworks we have support to handle
        Args:
            f_name (String): framework name (eg. torch, crypten)
            f_package (imported module): imported library
        """
        def call_wrapped_framework(role, owner):
            return FrameworkWrapper(f_package, role, owner)

        Plan._wrapped_frameworks[f_name] = call_wrapped_framework

    def add_translation(self, plan_translator: "AbstractPlanTranslator"):
        return plan_translator(self).translate()

    def remove_translation(
            self,
            plan_translator: "AbstractPlanTranslator" = PlanTranslatorDefault):
        plan_translator(self).remove()
        return self

    def get_(self):
        self.state.get_()
        return self

    get = get_

    def get_pointers(self):
        return self.pointers

    def fix_precision_(self, *args, **kwargs):
        self.state.fix_precision_(*args, **kwargs)
        return self

    fix_precision = fix_prec_ = fix_prec = fix_precision_

    def float_precision_(self):
        self.state.float_precision_()
        return self

    float_precision = float_prec_ = float_prec = float_precision_

    def share_(self, *args, **kwargs):
        self.state.share_(*args, **kwargs)
        return self

    share = share_

    def create_pointer(self,
                       owner,
                       garbage_collect_data,
                       location=None,
                       id_at_location=None,
                       tags=None,
                       **kwargs):
        """
        Create a pointer to the plan

        Args:
            owner: the owner of the pointer
            garbage_collect_data: if true, when the pointer is deleted, the remote target is garbaged collected
            location: the location of the pointer
            id_at_location: the remote id at location
            tags: the tags inherited from the Plan

        Returns:
            PointerPlan: pointer to the plan
        """
        return PointerPlan(
            owner=owner,
            location=location or self.owner,
            id_at_location=id_at_location or self.id,
            garbage_collect_data=garbage_collect_data,
            tags=tags,
        )

    def __str__(self):
        """Returns the string representation of Plan."""
        out = "<"
        out += str(type(self)).split("'")[1].split(".")[-1]
        out += " " + str(self.name)
        out += " id:" + str(self.id)
        out += " owner:" + str(self.owner.id)

        if self.tags is not None and len(self.tags):
            out += " Tags:"
            for tag in self.tags:
                out += " " + str(tag)

        if self.is_built:
            out += " built"

        out += ">"
        out += "\n"
        _self = self

        # out += f"def {self.name}("
        # out += ", ".join(f"arg_{extract_tag(p)}" for p in self.find_placeholders("input"))
        # out += "):\n"
        # for action in self.actions:
        #     line = "    "
        #     if action.return_ids is not None:
        #         if isinstance(action.return_ids, PlaceHolder):
        #             tag = extract_tag(action.return_ids)
        #             line += f"_{tag} = "
        #         elif isinstance(action.return_ids, tuple):
        #             line += (
        #                 ", ".join(
        #                     f"_{extract_tag(o)}" if isinstance(o, PlaceHolder) else str(o)
        #                     for o in action.return_ids
        #                 )
        #                 + " = "
        #             )
        #         else:
        #             line += str(action.return_ids) + " = "
        #     if action.target is not None:
        #         line += f"_{extract_tag(self.placeholders[action.target.value])}."
        #     line += action.name + "("
        #     line += ", ".join(
        #         f"_{extract_tag(arg)}" if isinstance(arg, PlaceHolder) else str(arg)
        #         for arg in action.args
        #     )
        #     if action.kwargs:
        #         line += ", " + ", ".join(f"{k}={w}" for k, w in action.kwargs.items())
        #     line += ")\n"
        #     out += line

        # out += "    return "
        # out += ", ".join(f"_{extract_tag(p)}" for p in self.find_placeholders("output"))

        return out

    def __repr__(self):
        return self.__str__()

    @staticmethod
    def replace_non_instanciated_placeholders(plan: "Plan") -> "Plan":
        # Replace non-instanciated placeholders from plan.placeholders by instanciated placeholders
        # from state.state_placeholders
        # NOTE Maybe state shouldn't contain instanciated placeholders but values directly?
        state_placeholders = {
            ph.id.value: ph
            for ph in plan.state.state_placeholders
        }
        plan.placeholders = {**plan.placeholders, **state_placeholders}

        return plan

    @staticmethod
    def simplify(worker: AbstractWorker, plan: "Plan") -> tuple:
        """
        This function takes the attributes of a Plan and saves them in a tuple
        Args:
            worker (AbstractWorker): the worker doing the serialization
            plan (Plan): a Plan object
        Returns:
            tuple: a tuple holding the unique attributes of the Plan object

        """
        if not plan.is_built:
            raise RuntimeError(
                "A Plan needs to be built before being serialized.")

        return (
            sy.serde.msgpack.serde._simplify(worker, plan.id),
            sy.serde.msgpack.serde._simplify(worker, plan.role),
            sy.serde.msgpack.serde._simplify(worker, plan.include_state),
            sy.serde.msgpack.serde._simplify(worker, plan.name),
            sy.serde.msgpack.serde._simplify(worker, plan.tags),
            sy.serde.msgpack.serde._simplify(worker, plan.description),
            sy.serde.msgpack.serde._simplify(worker, plan.torchscript),
            sy.serde.msgpack.serde._simplify(worker, plan.input_types),
        )

    @staticmethod
    def detail(worker: AbstractWorker, plan_tuple: tuple) -> "Plan":
        """This function reconstructs a Plan object given its attributes in the form of a tuple.
        Args:
            worker: the worker doing the deserialization
            plan_tuple: a tuple holding the attributes of the Plan
        Returns:
            plan: a Plan object
        """
        (
            id_,
            role,
            include_state,
            name,
            tags,
            description,
            torchscript,
            input_types,
        ) = plan_tuple

        id_ = sy.serde.msgpack.serde._detail(worker, id_)
        role = sy.serde.msgpack.serde._detail(worker, role)
        name = sy.serde.msgpack.serde._detail(worker, name)
        tags = sy.serde.msgpack.serde._detail(worker, tags)
        description = sy.serde.msgpack.serde._detail(worker, description)
        torchscript = sy.serde.msgpack.serde._detail(worker, torchscript)
        input_types = sy.serde.msgpack.serde._detail(worker, input_types)

        plan = sy.Plan(
            role=role,
            include_state=include_state,
            is_built=True,
            id=id_,
            owner=worker,
            name=name,
            tags=tags,
            description=description,
            input_types=input_types,
        )

        plan.torchscript = torchscript

        return plan

    @staticmethod
    def bufferize(worker: AbstractWorker, plan: "Plan") -> PlanPB:
        """
        This function takes the attributes of a Plan and saves them in a Protobuf message
        Args:
            worker (AbstractWorker): the worker doing the serialization
            plan (Plan): a Plan object
        Returns:
            PlanPB: a Protobuf message holding the unique attributes of the Plan object
        """
        if not plan.is_built:
            raise RuntimeError(
                "A Plan needs to be built before being serialized.")

        protobuf_plan = PlanPB()

        sy.serde.protobuf.proto.set_protobuf_id(protobuf_plan.id, plan.id)

        protobuf_plan.role.CopyFrom(
            sy.serde.protobuf.serde._bufferize(worker, plan.role))

        protobuf_plan.include_state = plan.include_state
        protobuf_plan.name = plan.name
        protobuf_plan.tags.extend(plan.tags)

        if protobuf_plan.description:
            protobuf_plan.description = plan.description

        if plan.torchscript:
            protobuf_plan.torchscript = plan.torchscript.save_to_buffer()

        if plan.input_types:
            input_types = sy.serde.protobuf.serde._bufferize(
                worker, plan.input_types)
            protobuf_plan.input_types.CopyFrom(input_types)

        return protobuf_plan

    @staticmethod
    def unbufferize(worker: AbstractWorker, protobuf_plan: PlanPB) -> "Plan":
        """This function reconstructs a Plan object given its attributes in the form of a Protobuf message
        Args:
            worker: the worker doing the deserialization
            protobuf_plan: a Protobuf message holding the attributes of the Plan
        Returns:
            plan: a Plan object
        """
        id_ = sy.serde.protobuf.proto.get_protobuf_id(protobuf_plan.id)

        role = sy.serde.protobuf.serde._unbufferize(worker, protobuf_plan.role)

        name = protobuf_plan.name
        tags = set(protobuf_plan.tags) if protobuf_plan.tags else None
        description = protobuf_plan.description if protobuf_plan.description else None
        input_types = sy.serde.protobuf.serde._unbufferize(
            worker, protobuf_plan.input_types)

        plan = Plan(
            role=role,
            include_state=protobuf_plan.include_state,
            is_built=True,
            id=id_,
            owner=worker,
            name=name,
            tags=tags,
            description=description,
            input_types=input_types,
        )

        if protobuf_plan.torchscript:
            torchscript = io.BytesIO(protobuf_plan.torchscript)
            plan.torchscript = torch.jit.load(torchscript)

        return plan

    @property
    def code(self) -> str:
        """Returns string representation of Plan actions"""
        input_names = {
            id: f"arg_{i + 1}"
            for i, id in enumerate(self.role.input_placeholder_ids)
        }
        output_names = {
            id: f"out_{i + 1}"
            for i, id in enumerate(self.role.output_placeholder_ids)
        }
        state_names = {
            id: f"state_{i + 1}"
            for i, id in enumerate(self.role.state.state_placeholders)
        }
        var_names = {**input_names, **output_names, **state_names}

        out = f"def {self.name}("
        out += ", ".join(
            [var_names[id] for id in self.role.input_placeholder_ids])
        out += "):\n"
        for action in self.role.actions:
            out += f"    {action.code(var_names)}\n"

        out += "    return "
        out += ", ".join(
            [var_names[id] for id in self.role.output_placeholder_ids])

        return out

    @staticmethod
    def get_protobuf_schema() -> PlanPB:
        return PlanPB
コード例 #2
0
ファイル: plan.py プロジェクト: theAkhileshRai/PySyft
    def build(self, *args, trace_autograd=False):
        """Builds the plan.

        First, run the function to be converted in a plan in a context which
        activates the tracing and record the actions in trace.logs

        Second, store the result ids temporarily to helper ordering the output
        placeholders at return time

        Third, loop through the trace logs and replace the tensors found in the
        actions logged by PlaceHolders. Record those actions in
        plan.actions

        Args:
            args: Input arguments to run the plan
        """
        # Reset previous build
        self.role.reset()

        def build_nested_arg(arg, leaf_function):
            if isinstance(arg, list):
                return [build_nested_arg(obj, leaf_function) for obj in arg]
            elif isinstance(arg, tuple):
                return tuple(
                    [build_nested_arg(obj, leaf_function) for obj in arg])
            elif isinstance(arg, dict):
                return {
                    k: build_nested_arg(v, leaf_function)
                    for k, v in arg.items()
                }
            else:
                return leaf_function(arg)

        # Enable tracing
        self.toggle_tracing(True)
        self.is_building = True

        # typecheck
        self.input_types = NestedTypeWrapper(args)

        # Run once to build the plan
        if trace_autograd:
            # Wrap arguments that require gradients with AutogradTensor,
            # to be able to trace autograd operations
            args = build_nested_arg(
                args,
                lambda x: AutogradTensor().on(x, wrap=False)
                if isinstance(x, FrameworkTensor) and x.requires_grad else x,
            )
            # Add Placeholder after AutogradTensor in the chain
            # so that all operations that happen inside AutogradTensor are recorded by Placeholder
            args_placeholders = build_nested_arg(
                args,
                lambda x: PlaceHolder.insert(
                    x, AutogradTensor, role=self.role, tracing=True),
            )
        else:
            # Add Placeholder on top of each arg
            args = args_placeholders = build_nested_arg(
                args,
                lambda x: PlaceHolder.create_from(
                    x, role=self.role, tracing=True),
            )

        # Add state to args if needed
        if self.include_state:
            args += (self.state, )

        # Check the plan arguments to see what framework wrappers we might need to send to the plan
        framework_kwargs = {}

        forward_args = inspect.getfullargspec(self.forward).args
        for f_name, wrap_framework_func in Plan._wrapped_frameworks.items():
            if f_name in forward_args:
                framework_kwargs[f_name] = wrap_framework_func(
                    self.role, self.owner)

        results = self.forward(*args, **framework_kwargs)

        # Disable tracing
        self.toggle_tracing(False)
        self.is_building = False

        # Register inputs in role
        self.role.register_inputs(args_placeholders)

        # Register outputs in role
        if isinstance(results, (tuple, list)):
            results_placeholders = tuple(
                PlaceHolder.extract(result) for result in results)
        else:
            results_placeholders = PlaceHolder.extract(results)
        self.role.register_outputs(results_placeholders)

        self.is_built = True

        # Build registered translations
        for translator in Plan._build_translators:
            try:
                self.add_translation(translator)
                self.translations.append(translator)
            except:
                warnings.warn(
                    f"Failed to translate Plan with {translator.__name__}")

        return results