Esempio n. 1
0
    def __init__(self,
                 owner=None,
                 id=None,
                 tags: set = None,
                 description: str = None,
                 shape=None):
        """Create an empty PlaceHolder that stands in for a tensor.

        A PlaceHolder behaves like a tensor but does nothing special on its
        own. It becomes "instantiated" once a real tensor is attached as its
        ``child`` attribute, and commands it receives are meant to be
        forwarded to that child. Sending a PlaceHolder does not send the
        instantiated tensor along with it.

        Args:
            owner: Optional BaseWorker specifying the worker on which the
                tensor is located.
            id: Optional string or integer id. Whatever id ends up on the
                object after the parent constructor runs is wrapped in a
                PlaceholderId if it is not one already.
            tags: Optional set of tags, forwarded to the parent constructor.
            description: Optional description, forwarded to the parent
                constructor.
            shape: Optional iterable giving the expected shape; stored as a
                tuple in ``expected_shape`` (``None`` when not given).
        """
        super().__init__(owner=owner, id=id, tags=tags, description=description)

        # Make sure self.id is always a PlaceholderId, whatever the parent set.
        if not isinstance(self.id, PlaceholderId):
            self.id = PlaceholderId(self.id)

        self.expected_shape = None if shape is None else tuple(shape)
        # No tensor attached yet.
        self.child = None
Esempio n. 2
0
    def read(self):
        """
        Return the state tensors that belong to this plan itself.

        Placeholders tagged "#inner" (state reported up from a nested sub-plan)
        are excluded. If a plan is currently being built
        (``self.owner.init_plan`` is truthy) and its state is not this object,
        a copy of every state placeholder of this object is first declared to
        that plan, re-identified by its child's id and tagged "#inner".

        Returns:
            list: the ``child`` of every non-"#inner" state placeholder, in order.
        """
        building_plan = self.owner.init_plan
        # A differing state object means we are a sub plan of the one being built.
        if building_plan and building_plan.state != self:
            for original in self.state_placeholders:
                reported = original.copy()
                reported.id = PlaceholderId(reported.child.id)
                reported.tags = set()
                reported.tag("#inner")
                building_plan.state.state_placeholders.append(reported)
                building_plan.role.placeholders[reported.child.id] = reported

        # Sub-plan state must not leak out through read().
        return [ph.child for ph in self.state_placeholders if "#inner" not in ph.tags]
Esempio n. 3
0
    def read(self):
        """
        Return the state of this plan itself, excluding nested sub-plan state.

        If a plan is currently being built (``self.owner.init_plan`` is truthy)
        and its state is not this object, a copy of every state placeholder of
        this object is first declared to that plan, re-identified by its
        child's id and tagged "#inner".

        Returns:
            list: the non-"#inner" state placeholders themselves when
            ``self.read_placeholders`` is truthy or a plan is being built;
            otherwise the ``child`` of each of those placeholders.
        """
        # TODO clean this function
        building_plan = self.owner.init_plan
        # A differing state object means we are a sub plan of the one being built.
        if building_plan and building_plan.state != self:
            for original in self.state_placeholders:
                reported = original.copy()
                reported.id = PlaceholderId(reported.child.id)
                reported.tags = {"#inner"}
                building_plan.state.state_placeholders.append(reported)
                building_plan.role.placeholders[reported.child.id] = reported

        own_placeholders = [
            ph for ph in self.state_placeholders if "#inner" not in ph.tags
        ]
        if self.read_placeholders or self.owner.init_plan:
            return own_placeholders
        return [ph.child for ph in own_placeholders]
Esempio n. 4
0
 def _replace_placeholder_ids(obj):
     """Rebuild *obj* recursively, swapping each PlaceholderId for its new id.

     Lists and tuples are rebuilt with their own type, dicts are rebuilt with
     the same keys, and PlaceholderId instances are replaced by a new
     PlaceholderId looked up through ``old_ids_2_new_ids``. That mapping comes
     from the enclosing scope (not visible here); presumably it maps old id
     values to new ones. Anything else is returned unchanged.
     """
     if isinstance(obj, (tuple, list)):
         rebuilt = list(map(_replace_placeholder_ids, obj))
         return type(obj)(rebuilt)
     if isinstance(obj, dict):
         return dict((key, _replace_placeholder_ids(val)) for key, val in obj.items())
     if not isinstance(obj, PlaceholderId):
         return obj
     return PlaceholderId(old_ids_2_new_ids[obj.value])
Esempio n. 5
0
    def add_placeholder(self, tensor, arg_ids, result_ids, node_type=None):
        """
        Create and register a new placeholder if not already existing (else return
        the existing one).

        The placeholder is tagged by a unique and incremental index for a given plan.

        Args:
            tensor: the tensor to replace with a placeholder
            arg_ids: ids of the plan's input tensors; the position of
                ``tensor.id`` in this sequence gives the ``#input-<i>`` tag.
            result_ids: ids of the plan's output tensors; the position of
                ``tensor.id`` in this sequence gives the ``#output-<i>`` tag.
            node_type: Should be "input" or "output", used to tag like this: #<type>-*

        Returns:
            PlaceholderId: an id wrapping ``tensor.id``.

        Raises:
            ValueError: if ``node_type`` is "input" but ``tensor.id`` is not in
                ``arg_ids``, or if ``node_type`` is neither "input" nor
                "output" for a tensor that has no placeholder yet. Nothing is
                registered and ``var_count`` is unchanged in that case.
        """
        if tensor.id in self.placeholders:
            return PlaceholderId(tensor.id)

        # Validate and compute the positional tags before creating/registering
        # anything, so a failing call leaves no half-initialised placeholder.
        io_tags = []
        if node_type == "input":
            if tensor.id not in arg_ids:
                raise ValueError(
                    f"The following tensor was used but is not known in "
                    f"this plan: \n{tensor}\nPossible reasons for this can be:\n"
                    f"- This tensor is external to the plan and should be provided "
                    f"using the state. See more about plan.state to fix this.\n"
                    f"- This tensor was created internally using torch.Tensor, "
                    f"torch.FloatTensor, torch.IntTensor, etc, which are not supported. "
                    f"Please use instead torch.tensor(..., dtype=torch.int32) for example."
                )
            io_tags.append(f"#input-{arg_ids.index(tensor.id)}")
            if tensor.id in result_ids:
                io_tags.append(f"#output-{result_ids.index(tensor.id)}")

        elif node_type == "output":
            if tensor.id in result_ids:
                io_tags.append(f"#output-{result_ids.index(tensor.id)}")

            if tensor.id in arg_ids:
                # The input position comes from arg_ids, not result_ids.
                io_tags.append(f"#input-{arg_ids.index(tensor.id)}")
        else:
            raise ValueError("node_type should be 'input' or 'output'.")

        placeholder = sy.PlaceHolder(tags={f"#{self.var_count + 1}"},
                                     id=tensor.id,
                                     owner=self.owner)
        for tag in io_tags:
            placeholder.tags.add(tag)
        self.placeholders[tensor.id] = placeholder

        self.var_count += 1

        return PlaceholderId(tensor.id)
Esempio n. 6
0
    def copy(self):
        """Return a new Role mirroring this one, with freshly-identified placeholders.

        Each placeholder is duplicated via its own ``copy()``. An old-id ->
        new-id map is built from those duplicates and used to remap the
        input/output placeholder ids, the state placeholders and every
        PlaceholderId nested in the actions' target/args/kwargs/return_ids.
        The new Role takes its id from ``sy.ID_PROVIDER.pop()``.
        """
        # TODO not the cleanest method ever
        id_map = {}
        new_placeholders = {}
        for original_ph in self.placeholders.values():
            duplicate = original_ph.copy()
            id_map[original_ph.id.value] = duplicate.id.value
            new_placeholders[duplicate.id.value] = duplicate

        def remap_key(key):
            # Translate a key of self.placeholders into the copied placeholder's id.
            return id_map[self.placeholders[key].id.value]

        new_input_placeholder_ids = tuple(map(remap_key, self.input_placeholder_ids))
        new_output_placeholder_ids = tuple(map(remap_key, self.output_placeholder_ids))

        # NOTE(review): only the id and child are carried over to the new state
        # placeholders; tags/owner are not passed here — confirm this is intended.
        state = State([
            PlaceHolder(id=id_map[state_ph.id.value]).instantiate(state_ph.child)
            for state_ph in self.state.state_placeholders
        ])

        def replace_ids(obj):
            # Swap every nested PlaceholderId for one carrying the new id value.
            return Role.nested_object_traversal(
                obj, lambda ph_id: PlaceholderId(id_map[ph_id.value]),
                PlaceholderId)

        new_actions = [
            type(action)(
                action.name,
                replace_ids(action.target),
                replace_ids(action.args),
                replace_ids(action.kwargs),
                replace_ids(action.return_ids),
            )
            for action in self.actions
        ]

        return Role(
            state=state,
            actions=new_actions,
            placeholders=new_placeholders,
            input_placeholder_ids=new_input_placeholder_ids,
            output_placeholder_ids=new_output_placeholder_ids,
            id=sy.ID_PROVIDER.pop(),
        )