Beispiel #1
0
    def __init__(
        self,
        id: Union[str, int] = None,
        worker: AbstractWorker = None,
        state: State = None,
        actions: List[Action] = None,
        placeholders: Dict[Union[str, int], PlaceHolder] = None,
        input_placeholder_ids: Tuple[Union[str, int], ...] = None,
        output_placeholder_ids: Tuple[Union[str, int], ...] = None,
    ):
        """Initialize the role with its actions, placeholders and state.

        Args:
            id: Identifier of the role. A fresh id is popped from
                ``sy.ID_PROVIDER`` only when ``None`` is passed, so falsy
                ids such as ``0`` are preserved.
            worker: Worker attached to the role; defaults to
                ``sy.local_worker``.
            state: State of the role; a new ``State()`` when not given.
            actions: Recorded actions; defaults to a new empty list.
            placeholders: All placeholders, keyed by id; defaults to a new
                empty dict.
            input_placeholder_ids: Ids of the input placeholders, in order;
                defaults to ``()``.
            output_placeholder_ids: Ids of the output placeholders, in order;
                defaults to ``()``.
        """
        # Generate an id only when none was supplied: ``id or ...`` would
        # discard a legitimate id of 0.
        self.id = id if id is not None else sy.ID_PROVIDER.pop()
        self.worker = worker or sy.local_worker

        self.actions = actions or []

        # All placeholders
        self.placeholders = placeholders or {}
        # Input placeholders, stored by id
        self.input_placeholder_ids = input_placeholder_ids or ()
        # Output placeholders, stored by id
        self.output_placeholder_ids = output_placeholder_ids or ()

        self.state = state or State()
        self.tracing = False

        # Expose one FrameworkWrapper per supported framework package as an
        # attribute named after that framework (wrapper is bound to this role).
        for name, package in framework_packages.items():
            tracing_wrapper = FrameworkWrapper(package=package, role=self)
            setattr(self, name, tracing_wrapper)
Beispiel #2
0
    def __init__(
        self,
        name: str = None,
        include_state: bool = False,
        is_built: bool = False,
        forward_func=None,
        state_tensors=None,
        role: Role = None,
        # General kwargs
        id: Union[str, int] = None,
        owner: "sy.workers.BaseWorker" = None,
        tags: List[str] = None,
        input_types: list = None,
        description: str = None,
    ):
        """Initialize a plan.

        Args:
            name: Plan name; defaults to the class name.
            include_state: Stored as ``self.include_state``.
            is_built: Stored as ``self.is_built``.
            forward_func: Used as ``self.forward`` only when the instance
                does not already have a ``forward`` attribute (e.g. one
                defined by a subclass).
            state_tensors: Tensors registered as state on the newly created
                role. They are ignored when ``role`` is given. ``None`` (the
                default) means no tensors.
            role: Existing role to use; a new ``Role()`` is created when
                ``None``.
            id, owner, tags, description: Forwarded to
                ``AbstractObject.__init__``.
            input_types: Stored as ``self.input_types``.
        """
        AbstractObject.__init__(self, id, owner, tags, description, child=None)

        # Plan instance info
        self.name = name or self.__class__.__name__

        if role is None:
            # Only a freshly created role receives the given state tensors;
            # a caller-supplied role is used as-is.
            self.role = Role()
            for st in state_tensors or ():
                self.role.register_state_tensor(st)
        else:
            self.role = role

        self.include_state = include_state
        self.is_building = False
        self.state_attributes = {}
        self.is_built = is_built
        self.torchscript = None
        self.input_types = input_types
        self.tracing = False

        # The plan has not been sent so it has no reference to remote locations
        self.pointers = {}

        # Keep a forward method already defined on the instance/class;
        # otherwise fall back to the supplied function (possibly None).
        if not hasattr(self, "forward"):
            self.forward = forward_func

        # When framework methods are used (e.g. torch.randn), a framework
        # wrapper registers and keeps track of which methods are called.
        # Build one wrapper per framework this plan supports handling.
        self.wrapped_framework = {
            f_name: FrameworkWrapper(f_packages, self.role, self.owner)
            for f_name, f_packages in framework_packages.items()
        }

        # For PyTorch jit tracing compatibility
        self.__name__ = self.__repr__()

        # List of available translations
        self.translations = []
Beispiel #3
0
        state_names = {
            id: f"state_{i + 1}"
            for i, id in enumerate(self.role.state.state_placeholders)
        }
        var_names = {**input_names, **output_names, **state_names}

        out = f"def {self.name}("
        out += ", ".join(
            [var_names[id] for id in self.role.input_placeholder_ids])
        out += "):\n"
        for action in self.role.actions:
            out += f"    {action.code(var_names)}\n"

        out += "    return "
        out += ", ".join(
            [var_names[id] for id in self.role.output_placeholder_ids])

        return out

    @staticmethod
    def get_protobuf_schema() -> PlanPB:
        """Return ``PlanPB``, the protobuf schema class for this type."""
        schema = PlanPB
        return schema


# Auto-register Plan build-time translations: executed once at import time,
# adding PlanTranslatorTorchscript via Plan.register_build_translator.
Plan.register_build_translator(PlanTranslatorTorchscript)

# Auto-register Plan build-time frameworks: one Plan.register_framework call
# per (name, package) pair in framework_packages. Note: f_name / f_package
# stay bound at module level after this loop.
for f_name, f_package in framework_packages.items():
    Plan.register_framework(f_name, f_package)