def __init__(self, owner=None, id=None, tags: set = None, description: str = None, shape=None):
    """Create an empty PlaceHolder.

    A PlaceHolder acts as a tensor but does nothing special. It gets
    "instantiated" when a real tensor is attached as its ``child``
    attribute, and it forwards all the commands it receives to that child.
    When a PlaceHolder is sent, the instantiated tensors are not sent.

    Args:
        owner: An optional BaseWorker object to specify the worker on which
            the tensor is located.
        id: An optional string or integer id of the PlaceHolder.
        tags: Optional set of tags, forwarded to the parent constructor.
        description: Optional description, forwarded to the parent
            constructor.
        shape: Optional iterable; when given it is stored as a tuple in
            ``expected_shape``, otherwise ``expected_shape`` is None.
    """
    super().__init__(id=id, owner=owner, tags=tags, description=description)

    # Whatever id the parent constructor settled on, wrap it so that
    # ``self.id`` is always a PlaceholderId.
    if not isinstance(self.id, PlaceholderId):
        self.id = PlaceholderId(self.id)

    if shape is None:
        self.expected_shape = None
    else:
        self.expected_shape = tuple(shape)

    # No tensor is attached yet.
    self.child = None
def read(self):
    """Return the tensors held by this state's own placeholders.

    Placeholders tagged "#inner" (those reported from another state) are
    skipped, so only the children of this state's own placeholders are
    returned.

    If a plan is currently building (``self.owner.init_plan`` is truthy) and
    that plan's state is not this state, every placeholder of this state is
    first copied, re-identified by its child's id, tagged "#inner" and
    registered both in the building plan's state and in its role.

    Returns:
        A list with the ``child`` of every placeholder not tagged "#inner".
    """
    building_plan = self.owner.init_plan

    # A building plan whose state differs from this one means we are being
    # read from inside a sub plan: report our placeholders to the parent.
    if building_plan and building_plan.state != self:
        for source_placeholder in self.state_placeholders:
            reported = source_placeholder.copy()
            child_id = reported.child.id
            reported.id = PlaceholderId(child_id)
            # Drop the copied tags and mark the origin with #inner only.
            reported.tags = set()
            reported.tag("#inner")
            building_plan.state.state_placeholders.append(reported)
            building_plan.role.placeholders[child_id] = reported

    # State elements coming from a sub plan must not be reported by read().
    return [
        candidate.child
        for candidate in self.state_placeholders
        if "#inner" not in candidate.tags
    ]
def read(self):
    """Return this state's own elements, skipping those tagged "#inner".

    If a plan is currently building (``self.owner.init_plan`` is truthy) and
    that plan's state is not this state, every placeholder of this state is
    first copied, re-identified by its child's id, given the single tag
    "#inner" and registered both in the building plan's state and in its
    role.

    Returns:
        The placeholders themselves when ``self.read_placeholders`` is
        truthy or a plan is building; otherwise the ``child`` of each of
        those placeholders.
    """
    # TODO clean this function
    building_plan = self.owner.init_plan

    # A building plan whose state differs from this one means we are being
    # read from inside a sub plan: report our placeholders to the parent.
    if building_plan and building_plan.state != self:
        for source_placeholder in self.state_placeholders:
            reported = source_placeholder.copy()
            child_id = reported.child.id
            reported.id = PlaceholderId(child_id)
            reported.tags = {"#inner"}
            building_plan.state.state_placeholders.append(reported)
            building_plan.role.placeholders[child_id] = reported

    want_placeholders = self.read_placeholders or self.owner.init_plan

    # Elements reported from a sub plan are never part of the result.
    own_placeholders = [
        candidate
        for candidate in self.state_placeholders
        if "#inner" not in candidate.tags
    ]
    if want_placeholders:
        return own_placeholders
    return [candidate.child for candidate in own_placeholders]
def _replace_placeholder_ids(obj): if isinstance(obj, (tuple, list)): r = [_replace_placeholder_ids(o) for o in obj] return type(obj)(r) elif isinstance(obj, dict): return {key: _replace_placeholder_ids(value) for key, value in obj.items()} elif isinstance(obj, PlaceholderId): return PlaceholderId(old_ids_2_new_ids[obj.value]) else: return obj
def add_placeholder(self, tensor, arg_ids, result_ids, node_type=None):
    """
    Create and register a new placeholder if not already existing.

    The placeholder is tagged by a unique and incremental index for a given
    plan (``#<var_count + 1>``), plus ``#input-<i>`` when the tensor id is
    found at position ``i`` of ``arg_ids`` and ``#output-<j>`` when it is
    found at position ``j`` of ``result_ids``. If a placeholder is already
    registered for ``tensor.id``, nothing is created or modified and
    ``node_type`` is not checked.

    Args:
        tensor: the tensor to replace with a placeholder
        arg_ids: sequence of the ids of the plan arguments; must support
            ``in`` and ``.index``
        result_ids: sequence of the ids of the plan results; must support
            ``in`` and ``.index``
        node_type: Should be "input" or "output", used to tag like this: #<type>-*

    Returns:
        A PlaceholderId wrapping ``tensor.id``.

    Raises:
        ValueError: if a new placeholder is needed and ``node_type`` is
            neither "input" nor "output", or ``node_type`` is "input" and
            the tensor id is not in ``arg_ids``.
    """
    tensor_id = tensor.id

    if tensor_id not in self.placeholders:
        # Validate before creating or registering anything, so a failed
        # call does not leave a partly tagged placeholder behind.
        if node_type not in ("input", "output"):
            raise ValueError("node_type should be 'input' or 'output'.")
        if node_type == "input" and tensor_id not in arg_ids:
            raise ValueError(
                f"The following tensor was used but is not known in "
                f"this plan: \n{tensor}\nPossible reasons for this can be:\n"
                f"- This tensor is external to the plan and should be provided "
                f"using the state. See more about plan.state to fix this.\n"
                f"- This tensor was created internally using torch.Tensor, "
                f"torch.FloatTensor, torch.IntTensor, etc, which are not supported. "
                f"Please use instead torch.tensor(..., dtype=torch.int32) for example."
            )

        placeholder = sy.PlaceHolder(
            tags={f"#{self.var_count + 1}"}, id=tensor_id, owner=self.owner
        )

        # The input position always comes from arg_ids and the output
        # position from result_ids, whichever node_type is being added.
        if tensor_id in arg_ids:
            placeholder.tags.add(f"#input-{arg_ids.index(tensor_id)}")
        if tensor_id in result_ids:
            placeholder.tags.add(f"#output-{result_ids.index(tensor_id)}")

        self.placeholders[tensor_id] = placeholder
        self.var_count += 1

    return PlaceholderId(tensor_id)
def copy(self):
    """Return a new Role built from copies of this role's contents.

    Every placeholder is copied, and a mapping from each original
    placeholder id value to its copy's id value is used to rewrite the
    input/output placeholder ids, the state placeholders and every
    PlaceholderId found in the actions' target, args, kwargs and
    return_ids. The new Role receives a fresh id popped from
    ``sy.ID_PROVIDER``.
    """
    # TODO not the cleanest method ever
    id_mapping = {}
    copied_placeholders = {}
    for original in self.placeholders.values():
        duplicate = original.copy()
        duplicate_id = duplicate.id.value
        id_mapping[original.id.value] = duplicate_id
        copied_placeholders[duplicate_id] = duplicate

    def remap_registered(placeholder_key):
        # Resolve a key of self.placeholders to the id of the matching copy.
        return id_mapping[self.placeholders[placeholder_key].id.value]

    copied_input_ids = tuple(
        remap_registered(input_id) for input_id in self.input_placeholder_ids
    )
    copied_output_ids = tuple(
        remap_registered(output_id) for output_id in self.output_placeholder_ids
    )

    # State placeholders are rebuilt under their new id and instantiated
    # with the same child as the original.
    copied_state = State(
        [
            PlaceHolder(id=id_mapping[state_ph.id.value]).instantiate(state_ph.child)
            for state_ph in self.state.state_placeholders
        ]
    )

    def remap_nested(obj):
        # Swap every PlaceholderId found in obj for one carrying the new id.
        return Role.nested_object_traversal(
            obj,
            lambda placeholder_id: PlaceholderId(id_mapping[placeholder_id.value]),
            PlaceholderId,
        )

    copied_actions = [
        type(action)(
            action.name,
            remap_nested(action.target),
            remap_nested(action.args),
            remap_nested(action.kwargs),
            remap_nested(action.return_ids),
        )
        for action in self.actions
    ]

    return Role(
        state=copied_state,
        actions=copied_actions,
        placeholders=copied_placeholders,
        input_placeholder_ids=copied_input_ids,
        output_placeholder_ids=copied_output_ids,
        id=sy.ID_PROVIDER.pop(),
    )