def create_dummy(input_type, input_placeholder):
    """Return a dummy value of ``input_type`` matching ``input_placeholder``.

    For ``FrameworkTensor`` subclasses, a placeholder is created from
    ``input_placeholder.expected_shape`` and wrapped in ``input_type``.
    Any other type is simply instantiated with no arguments.
    """
    if not issubclass(input_type, FrameworkTensor):
        return input_type()

    placeholders = PlaceHolder.create_placeholders([input_placeholder.expected_shape])
    return input_type(placeholders[0])
def translate(self):
    """Trace a copy of ``self.plan`` with ``jit.trace`` and attach the result.

    The copy has its ``forward`` cleared before tracing. If the plan has
    parameters, the traced function is a wrapper that takes those parameters
    as its last argument. This keeps the state tensors out of the TorchScript
    graph.

    Returns:
        ``self.plan``, with ``torchscript`` set to the traced result.
    """
    plan_copy = self.plan.copy()
    plan_copy.forward = None

    placeholders = PlaceHolder.create_placeholders(plan_copy.get_args_shape())

    # State tensors are passed as a trailing parameter rather than stored in
    # the TorchScript. The wrapper below receives them as its last argument
    # and sets them into the Plan before running it.
    def wrap_stateful_plan(*call_args):
        role = plan_copy.role
        state = call_args[-1]
        state_phs = role.state.state_placeholders
        if 0 < len(state_phs) == len(state) and isinstance(state, (list, tuple)):
            role_state_phs = tuple(role.placeholders[ph.id.value] for ph in state_phs)
            PlaceHolder.instantiate_placeholders(state_phs, state)
            PlaceHolder.instantiate_placeholders(role_state_phs, state)
        return plan_copy(*call_args[:-1])

    plan_params = plan_copy.parameters()
    if len(plan_params) == 0:
        traced = jit.trace(plan_copy, placeholders)
    else:
        traced = jit.trace(wrap_stateful_plan, (*placeholders, plan_params))

    self.plan.torchscript = traced
    return self.plan
def create_dummy(input_type, input_placeholder):
    """Return a dummy value of ``input_type`` matching ``input_placeholder``.

    For ``FrameworkTensor`` subclasses, a placeholder is created from the
    placeholder's ``expected_shape`` and ``expected_dtype``. It is wrapped in
    ``input_type`` only when its type is not already ``input_type``. Any other
    type is instantiated with no arguments.
    """
    if not issubclass(input_type, FrameworkTensor):
        return input_type()

    shapes = [input_placeholder.expected_shape]
    dtypes = [input_placeholder.expected_dtype]
    dummy = PlaceHolder.create_placeholders(shapes, dtypes)[0]

    # Avoid re-wrapping when the created placeholder already has the requested type.
    needs_cast = input_type != type(dummy)
    return input_type(dummy) if needs_cast else dummy
def translate(self):
    """Trace ``self.plan`` with ``jit.trace`` and store the result on the plan.

    ``plan.forward`` is cleared for the duration of tracing. It is always
    restored afterwards, even if tracing raises; previously a failing trace
    left the plan with ``forward = None``.

    Returns:
        The same plan object, with ``plan.torchscript`` set to the traced result.
    """
    plan = self.plan
    args_shape = plan.get_args_shape()
    args = PlaceHolder.create_placeholders(args_shape)

    # Temporarily remove reference to original function.
    tmp_forward = plan.forward
    plan.forward = None
    try:
        # To avoid storing Plan state tensors in torchscript, they are sent as
        # an extra trailing argument instead.
        plan_params = plan.parameters()
        if len(plan_params) > 0:
            args = (*args, plan_params)
        plan.torchscript = jit.trace(plan, args)
    finally:
        # Restore the original function whether or not tracing succeeded.
        plan.forward = tmp_forward
    return plan
def __call__(self, protocol_function):
    """Wrap ``protocol_function`` in a ``Protocol`` and optionally build it.

    If ``self.args_shape`` is set, placeholders are created from those shapes
    and passed to ``protocol.build``.

    Returns:
        The new ``Protocol``.

    Raises:
        ValueError: if ``protocol.build`` raises ``TypeError``. The original
            error is chained as ``__cause__``.
    """
    protocol = Protocol(
        name=protocol_function.__name__,
        forward_func=protocol_function,
        id=sy.ID_PROVIDER.pop(),
        owner=sy.local_worker,
    )

    # Build the protocol automatically
    if self.args_shape:
        args_ = PlaceHolder.create_placeholders(self.args_shape)
        try:
            protocol.build(*args_)
        except TypeError as e:
            # Chain explicitly so the underlying TypeError stays visible as the cause.
            raise ValueError(
                "Automatic build using @func2protocol failed!\nCheck that:\n"
                " - you have provided the correct number of shapes in args_shape\n"
                " - you have no simple numbers like int or float as args. If you do "
                "so, please consider using a tensor instead."
            ) from e
    return protocol
def __call__(self, plan_function):
    """Wrap ``plan_function`` in a ``Plan`` and optionally build it.

    The plan is created with this decorator's ``include_state`` and
    ``state_tensors``. If ``self.args_shape`` is set, placeholders are created
    from those shapes and passed to ``plan.build`` together with
    ``trace_autograd=self.trace_autograd``.

    Returns:
        The new ``Plan``.

    Raises:
        ValueError: if ``plan.build`` raises ``TypeError``. The original error
            is chained as ``__cause__``.
    """
    plan = Plan(
        name=plan_function.__name__,
        include_state=self.include_state,
        forward_func=plan_function,
        state_tensors=self.state_tensors,
        id=sy.ID_PROVIDER.pop(),
        owner=sy.local_worker,
    )

    # Build the plan automatically
    if self.args_shape:
        args_ = PlaceHolder.create_placeholders(self.args_shape)
        try:
            plan.build(*args_, trace_autograd=self.trace_autograd)
        except TypeError as e:
            # Chain explicitly so the underlying TypeError stays visible as the cause.
            raise ValueError(
                "Automatic build using @func2plan failed!\nCheck that:\n"
                " - you have provided the correct number of shapes in args_shape\n"
                " - you have no simple numbers like int or float as args. If you do "
                "so, please consider using a tensor instead.") from e
    return plan
def test_plan_can_be_jit_traced(hook, workers):
    """A stateful plan made with ``func2plan`` can be traced by ``th.jit.trace``.

    Both the plan and its traced version must map [1.0, 2] to [3.0, 5]
    (x * 2 + bias, with bias = 1.0).
    """
    shapes = [(1, )]

    @sy.func2plan(args_shape=shapes, state=(th.tensor([1.0]), ))
    def foo(x, state):
        (bias, ) = state.read()
        x = x * 2
        return x + bias

    # The decorated plan should already be built and have recorded actions.
    assert isinstance(foo.__str__(), str)
    assert len(foo.actions) > 0
    assert foo.is_built

    # Note: the input has shape (2,), while the plan was declared with shape (1,).
    input_tensor = th.tensor([1.0, 2])
    plan_output = foo(input_tensor)
    assert (plan_output == th.tensor([3.0, 5])).all()

    placeholders = PlaceHolder.create_placeholders(shapes)
    traced = th.jit.trace(foo, placeholders)
    traced_output = traced(input_tensor)
    assert (traced_output == th.tensor([3.0, 5])).all()
def translate(self):
    """Trace ``self.plan`` with ``jit.trace`` and store the result on the plan.

    If the plan has parameters, a wrapper is traced instead of the plan. The
    wrapper takes the parameters as its last argument and sets them into the
    Plan before running it, so the state tensors are not stored inside the
    TorchScript.

    ``plan.forward`` is cleared for the duration of tracing. It is always
    restored afterwards, even if tracing raises; previously a failing trace
    left the plan with ``forward = None``.

    Returns:
        The same plan object, with ``plan.torchscript`` set to the traced result.
    """
    plan = self.plan
    args_shape = plan.get_args_shape()
    args = PlaceHolder.create_placeholders(args_shape)

    def wrap_stateful_plan(*wrapper_args):
        # The last argument carries the state parameters; the rest are plan inputs.
        role = plan.role
        state = wrapper_args[-1]
        # Check the type first so a non-sequence state is skipped rather than
        # raising TypeError inside len().
        if isinstance(state, (list, tuple)) and 0 < len(
            role.state.state_placeholders
        ) == len(state):
            state_placeholders = tuple(
                role.placeholders[ph.id.value] for ph in role.state.state_placeholders
            )
            PlaceHolder.instantiate_placeholders(role.state.state_placeholders, state)
            PlaceHolder.instantiate_placeholders(state_placeholders, state)
        return plan(*wrapper_args[:-1])

    # Temporarily remove reference to original function.
    tmp_forward = plan.forward
    plan.forward = None
    try:
        plan_params = plan.parameters()
        if len(plan_params) > 0:
            torchscript_plan = jit.trace(wrap_stateful_plan, (*args, plan_params))
        else:
            torchscript_plan = jit.trace(plan, args)
        plan.torchscript = torchscript_plan
    finally:
        # Restore the original function whether or not tracing succeeded.
        plan.forward = tmp_forward
    return plan