def test_plan_translated_on_build(hook, workers):
    """Check that a plan built while the TorchScript translator is registered
    exposes a ``torchscript`` callable whose output matches the plan's own
    output and the expected value ``|2 * x|``.
    """
    # Enable torchscript translator
    # NOTE(review): this registers the translator on the ``Plan`` class and is
    # not undone within this test — confirm other tests do not rely on it
    # being absent.
    Plan.register_build_translator(PlanTranslatorTorchscript)

    # NOTE(review): the plan is built with a (3, 3) placeholder but exercised
    # with a length-3 vector below. Both ops are elementwise, so the result
    # should not depend on the build shape — confirm the mismatch is intended.
    @sy.func2plan(args_shape=[(3, 3)])
    def plan(x):
        x = x * 2
        x = x.abs()
        return x

    inp = th.tensor([1, -1, 2])
    res1 = plan(inp)
    res2 = plan.torchscript(inp)

    # The eager plan and its TorchScript translation must agree...
    assert (res1 == res2).all()
    # ...and must both be correct: |2 * [1, -1, 2]| == [2, 2, 4]. Without this
    # check, a bug shared by both execution paths would go unnoticed.
    assert (res1 == th.tensor([2, 2, 4])).all()
def test_plan_can_be_jit_traced(hook, workers):
    """Build a stateful plan and check that ``torch.jit.trace`` can trace it.

    The plan doubles its input and adds a bias read from its state. Both the
    plan and the traced module must map ``[1.0, 2.0]`` to ``[3.0, 5.0]``.
    """
    shapes = [(1,)]
    bias_state = (th.tensor([1.0]),)

    @sy.func2plan(args_shape=shapes, state=bias_state)
    def foo(x, state):
        (bias,) = state.read()
        doubled = x * 2
        return doubled + bias

    # After decoration the plan is built and has a printable, non-empty form.
    assert isinstance(foo.__str__(), str)
    assert len(foo.readable_plan) > 0
    assert foo.is_built

    data = th.tensor([1.0, 2])

    eager_out = foo(data)
    expected = th.tensor([3.0, 5])
    assert (eager_out == expected).all()

    # Trace the plan with placeholder arguments built from the declared
    # shapes, then run the traced module on the real input.
    placeholders = Plan._create_placeholders(shapes)
    traced = th.jit.trace(foo, placeholders)
    traced_out = traced(data)
    assert (traced_out == expected).all()
if 0 < len(role.state.state_placeholders) == len(state) and isinstance( state, (list, tuple) ): state_placeholders = tuple( role.placeholders[ph.id.value] for ph in role.state.state_placeholders ) PlaceHolder.instantiate_placeholders(role.state.state_placeholders, state) PlaceHolder.instantiate_placeholders(state_placeholders, state) return plan(*args[:-1]) plan_params = plan.parameters() if len(plan_params) > 0: torchscript_plan = jit.trace(wrap_stateful_plan, (*args, plan_params)) else: torchscript_plan = jit.trace(plan, args) plan.torchscript = torchscript_plan plan.forward = tmp_forward return plan def remove(self): plan = self.plan plan.torchscript = None return plan # Register translators that should apply at Plan build time Plan.register_build_translator(PlanTranslatorTorchscript)