def test_plan_translated_on_build(hook, workers):
    """Verify that building a Plan with the torchscript translator registered
    attaches a ``torchscript`` callable whose output agrees with the plan's.

    ``hook`` and ``workers`` are pytest fixtures requested for their setup;
    they are not referenced directly in the body.
    """
    # Register the torchscript translator so it applies at Plan build time.
    Plan.register_build_translator(PlanTranslatorTorchscript)

    # args_shape presumably causes the plan to be built at decoration time
    # (plan.torchscript is used below without an explicit build call).
    @sy.func2plan(args_shape=[(3, 3)])
    def plan(x):
        return (x * 2).abs()

    sample = th.tensor([1, -1, 2])
    plan_output = plan(sample)
    script_output = plan.torchscript(sample)

    # The plan and its traced torchscript version must agree element-wise.
    assert (plan_output == script_output).all()
if 0 < len(role.state.state_placeholders) == len(state) and isinstance( state, (list, tuple) ): state_placeholders = tuple( role.placeholders[ph.id.value] for ph in role.state.state_placeholders ) PlaceHolder.instantiate_placeholders(role.state.state_placeholders, state) PlaceHolder.instantiate_placeholders(state_placeholders, state) return plan(*args[:-1]) plan_params = plan.parameters() if len(plan_params) > 0: torchscript_plan = jit.trace(wrap_stateful_plan, (*args, plan_params)) else: torchscript_plan = jit.trace(plan, args) plan.torchscript = torchscript_plan plan.forward = tmp_forward return plan def remove(self): plan = self.plan plan.torchscript = None return plan # Register translators that should apply at Plan build time Plan.register_build_translator(PlanTranslatorTorchscript)