Ejemplo n.º 1
0
    def __init__(
        self,
        name: str = None,
        procedure: Procedure = None,
        state: State = None,
        include_state: bool = False,
        is_built: bool = False,
        # Optional kwargs if commands or state are not provided
        state_ids: List[Union[str, int]] = None,
        arg_ids: List[Union[str, int]] = None,
        result_ids: List[Union[str, int]] = None,
        readable_plan: List = None,
        blueprint=None,
        state_tensors=None,
        # General kwargs
        id: Union[str, int] = None,
        owner: "sy.workers.BaseWorker" = None,
        tags: List[str] = None,
        description: str = None,
    ):
        """Initialise a plan and its procedure/state containers.

        A falsy ``owner`` is replaced by ``sy.local_worker``.

        ``procedure`` and ``state`` are used as given when truthy. Otherwise
        they are built:

        - the procedure from ``readable_plan``, ``arg_ids`` and ``result_ids``;
        - the state from ``state_ids``, with a back-reference to this plan.

        Every tensor in ``state_tensors`` (when not None) has its ``id``
        appended to ``self.state.state_ids`` and is registered with
        ``self.owner``.

        When ``blueprint`` is given it becomes ``self.forward``. Otherwise a
        plan flagged ``is_built`` gets ``self.forward = None``, and
        ``self.forward`` is left untouched in the remaining case.
        """
        if not owner:
            owner = sy.local_worker
        AbstractObject.__init__(self, id, owner, tags, description, child=None)
        ObjectStorage.__init__(self)

        # Identity of this plan instance; the default name is the concrete class name.
        self.name = name if name else self.__class__.__name__
        self.owner = owner

        # States of nested plans ("plans in plans"); they are kept here so they
        # can be serialized and sent to remote workers together with this plan.
        self.nested_states = []

        # The plan's content lives in its procedure and its state.
        if not procedure:
            procedure = Procedure(readable_plan, arg_ids, result_ids)
        self.procedure = procedure

        if not state:
            state = State(owner=owner, plan=self, state_ids=state_ids)
        self.state = state

        if state_tensors is not None:
            for state_tensor in state_tensors:
                # Record the tensor as part of the state, then make it known to the owner.
                self.state.state_ids.append(state_tensor.id)
                self.owner.register_obj(state_tensor)

        self.include_state = include_state
        self.is_built = is_built
        self.input_shapes = None
        self._output_shape = None

        # Starts empty because the plan has not been sent yet.
        self.pointers = {}

        if blueprint is None:
            if self.is_built:
                self.forward = None
        else:
            self.forward = blueprint
Ejemplo n.º 2
0
    def __init__(
        self,
        name: str = None,
        procedure: Procedure = None,
        state: State = None,
        include_state: bool = False,
        is_built: bool = False,
        # Optional kwargs if commands or state are not provided
        state_ids: List[Union[str, int]] = None,
        arg_ids: List[Union[str, int]] = None,
        result_ids: List[Union[str, int]] = None,
        readable_plan: List = None,
        blueprint=None,
        state_tensors=None,
        # General kwargs
        id: Union[str, int] = None,
        owner: "sy.workers.BaseWorker" = None,
        tags: List[str] = None,
        description: str = None,
    ):
        """Initialise a plan and its procedure/state containers.

        A falsy ``owner`` is replaced by ``sy.local_worker``.

        ``procedure`` and ``state`` are used as given when truthy. Otherwise
        they are built:

        - the procedure from ``readable_plan``, ``arg_ids`` and ``result_ids``;
        - the state from ``state_ids``, with a back-reference to this plan.

        Every tensor in ``state_tensors`` (when not None) has its ``id``
        appended to ``self.state.state_ids`` and is registered with
        ``self.owner``.

        When ``blueprint`` is given it becomes ``self.forward``. Otherwise a
        plan flagged ``is_built`` gets ``self.forward = None``, and
        ``self.forward`` is left untouched in the remaining case.
        """
        if not owner:
            owner = sy.local_worker
        AbstractObject.__init__(self, id, owner, tags, description, child=None)
        ObjectStorage.__init__(self)

        # Identity of this plan instance; the default name is the concrete class name.
        self.name = name if name else self.__class__.__name__
        self.owner = owner

        # The plan's content lives in its procedure and its state.
        if not procedure:
            procedure = Procedure(readable_plan, arg_ids, result_ids)
        self.procedure = procedure

        if not state:
            state = State(owner=owner, plan=self, state_ids=state_ids)
        self.state = state

        if state_tensors is not None:
            for state_tensor in state_tensors:
                # Record the tensor as part of the state, then make it known to the owner.
                self.state.state_ids.append(state_tensor.id)
                self.owner.register_obj(state_tensor)

        self.include_state = include_state
        self.is_built = is_built

        if blueprint is None:
            if self.is_built:
                self.forward = None
        else:
            self.forward = blueprint
Ejemplo n.º 3
0
def test_procedure_update_ids():
    """Check that ``Procedure`` rewrites ids inside its serialized operations.

    Three rewrites are exercised in order:

    1. ``update_ids`` re-points one object id and moves every pointer from
       worker "me" to worker "alice".
    2. ``update_args`` swaps the argument id for a fresh tensor's id and
       replaces the result id.
    3. ``update_worker_ids`` renames the worker on a hand-set operation list,
       including an entry whose worker id is given as bytes.
    """

    def _pointer(obj_id, id_at_location, location):
        # Serialized pointer tuple, as it appears inside the command below.
        return (23, (obj_id, id_at_location, location, None, (10, (1,)), True))

    def _add_command(self_pointer, arg_pointer, result_id):
        # A one-element operation list holding a serialized
        # ``self_pointer.__add__(arg_pointer)`` call with a single result id.
        message = (
            6,
            (
                (5, (b"__add__",)),
                self_pointer,
                (6, (arg_pointer,)),
                (0, ()),
            ),
        )
        return [(31, (1, (message, (result_id,))))]

    commands = _add_command(
        _pointer(27674294093, 68519530406, "me"),
        _pointer(2843683950, 91383408771, "me"),
        75165665059,
    )
    procedure = Procedure(operations=commands, arg_ids=[68519530406], result_ids=(75165665059,))

    # Step 1: re-point object 27674294093 -> 73570994542, and "me" -> "alice" everywhere.
    procedure.update_ids(
        from_ids=[27674294093], to_ids=[73570994542], from_worker="me", to_worker="alice"
    )
    assert procedure.operations == _add_command(
        _pointer(73570994542, 68519530406, "alice"),
        _pointer(2843683950, 91383408771, "alice"),
        75165665059,
    )

    # Step 2: the argument id (68519530406) becomes the new tensor's id,
    # and the result id is replaced.
    tensor = th.tensor([1.0])
    tensor_id = tensor.id
    procedure.update_args(args=(tensor,), result_ids=[8730174527])
    assert procedure.operations == _add_command(
        _pointer(73570994542, tensor_id, "alice"),
        _pointer(2843683950, 91383408771, "alice"),
        8730174527,
    )

    # Step 3: worker rename only. The first entry stores its worker id as bytes,
    # and the expectation is that it is also rewritten to the str "me".
    procedure.operations = [
        (73570994542, 8730174527, b"alice", None, (10, (1,)), True),
        (2843683950, 91383408771, "alice", None, (10, (1,)), True),
    ]
    procedure.update_worker_ids(from_worker_id="alice", to_worker_id="me")
    assert procedure.operations == [
        (73570994542, 8730174527, "me", None, (10, (1,)), True),
        (2843683950, 91383408771, "me", None, (10, (1,)), True),
    ]