def __init__(
    self,
    id: Union[str, int] = None,
    worker: AbstractWorker = None,
    state: State = None,
    actions: List[Action] = None,
    placeholders: Dict[Union[str, int], PlaceHolder] = None,
    input_placeholder_ids: Tuple[Union[int, str], ...] = None,
    output_placeholder_ids: Tuple[Union[int, str], ...] = None,
):
    """Create a Role holding traced actions, placeholders and state.

    Args:
        id: Identifier of the role. When ``None``, an id is obtained via
            ``sy.ID_PROVIDER.pop()``. Falsy ids such as ``0`` are kept.
        worker: Worker attached to the role; defaults to ``sy.local_worker``.
        state: State of the role; defaults to a new ``State()``.
        actions: Actions of the role; defaults to a new empty list.
        placeholders: Dict of placeholders, annotated as id -> PlaceHolder;
            defaults to a new empty dict.
        input_placeholder_ids: Ids of the input placeholders; defaults to ``()``.
        output_placeholder_ids: Ids of the output placeholders; defaults to ``()``.
    """
    # Defaults are resolved with explicit ``is None`` checks instead of
    # ``x or default``. Truthiness would silently replace legitimate falsy
    # arguments: id=0, an empty list/dict the caller still holds a reference
    # to, or any object whose __len__/__bool__ reports false.
    self.id = id if id is not None else sy.ID_PROVIDER.pop()
    self.worker = worker if worker is not None else sy.local_worker

    self.actions = actions if actions is not None else []

    # All placeholders
    self.placeholders = placeholders if placeholders is not None else {}
    # Input placeholders, stored by id
    self.input_placeholder_ids = (
        input_placeholder_ids if input_placeholder_ids is not None else ()
    )
    # Output placeholders, stored by id
    self.output_placeholder_ids = (
        output_placeholder_ids if output_placeholder_ids is not None else ()
    )

    self.state = state if state is not None else State()
    self.tracing = False

    # Expose one FrameworkWrapper per registered framework as an attribute
    # named after the framework (``setattr(self, name, ...)``), bound to this role.
    for name, package in framework_packages.items():
        tracing_wrapper = FrameworkWrapper(package=package, role=self)
        setattr(self, name, tracing_wrapper)
def __init__(
    self,
    name: str = None,
    include_state: bool = False,
    is_built: bool = False,
    forward_func=None,
    state_tensors=None,
    role: Role = None,
    # General kwargs
    id: Union[str, int] = None,
    owner: "sy.workers.BaseWorker" = None,
    tags: List[str] = None,
    input_types: list = None,
    description: str = None,
):
    """Create a Plan.

    Args:
        name: Plan name; defaults to the class name when empty or None.
        include_state: Stored as ``self.include_state``.
        is_built: Stored as ``self.is_built``.
        forward_func: Used as ``self.forward`` only when no ``forward``
            attribute already exists on the instance/class.
        state_tensors: Tensors registered as state on the Role created
            here. Ignored when ``role`` is supplied. ``None`` means none.
        role: Role to use; a new ``Role()`` is created when ``None``.
        id, owner, tags, description: Forwarded to ``AbstractObject.__init__``.
        input_types: Stored as ``self.input_types``.
    """
    AbstractObject.__init__(self, id, owner, tags, description, child=None)

    # Plan instance info
    self.name = name or self.__class__.__name__

    # A single ``is None`` test decides both whether a fresh Role is created
    # and whether state tensors are registered on it. A caller-supplied role
    # is always used as-is, even if it happens to be falsy.
    self.role = role if role is not None else Role()
    if role is None and state_tensors is not None:
        for st in state_tensors:
            self.role.register_state_tensor(st)

    self.include_state = include_state
    self.is_building = False
    self.state_attributes = {}
    self.is_built = is_built
    self.torchscript = None
    self.input_types = input_types
    self.tracing = False

    # The plan has not been sent so it has no reference to remote locations
    self.pointers = {}

    # Keep a ``forward`` that already exists (e.g. one defined on a subclass);
    # otherwise fall back to the supplied function, or None.
    if not hasattr(self, "forward"):
        self.forward = forward_func or None

    # When we use methods defined in a framework (like torch.randn), a
    # framework wrapper helps register and keep track of which methods are
    # called. Here we "register" every framework we support.
    self.wrapped_framework = {
        f_name: FrameworkWrapper(f_packages, self.role, self.owner)
        for f_name, f_packages in framework_packages.items()
    }

    self.__name__ = self.__repr__()  # For PyTorch jit tracing compatibility

    # List of available translations
    self.translations = []
state_names = { id: f"state_{i + 1}" for i, id in enumerate(self.role.state.state_placeholders) } var_names = {**input_names, **output_names, **state_names} out = f"def {self.name}(" out += ", ".join( [var_names[id] for id in self.role.input_placeholder_ids]) out += "):\n" for action in self.role.actions: out += f" {action.code(var_names)}\n" out += " return " out += ", ".join( [var_names[id] for id in self.role.output_placeholder_ids]) return out @staticmethod def get_protobuf_schema() -> PlanPB: return PlanPB # Auto-register Plan build-time translations Plan.register_build_translator(PlanTranslatorTorchscript) # Auto-register Plan build-time frameworks for f_name, f_package in framework_packages.items(): Plan.register_framework(f_name, f_package)