def test_plan_translated_on_build(hook, workers):
    """A registered TorchScript build translator attaches ``plan.torchscript``.

    ``func2plan`` builds the plan as soon as it is decorated (using
    ``args_shape`` for the placeholders). With the TorchScript translator
    registered, the built plan exposes a ``torchscript`` callable whose output
    must match both running the plan directly and the known result ``|2 * x|``.
    """
    # Enable torchscript translator
    Plan.register_build_translator(PlanTranslatorTorchscript)

    # Placeholder shape matches the 3-element ``inp`` fed to the plan below.
    @sy.func2plan(args_shape=[(3,)])
    def plan(x):
        x = x * 2
        x = x.abs()
        return x

    inp = th.tensor([1, -1, 2])
    res1 = plan(inp)
    res2 = plan.torchscript(inp)

    # Both execution paths must agree...
    assert (res1 == res2).all()
    # ...and must produce the correct values, |2 * [1, -1, 2]| == [2, 2, 4];
    # agreement alone would not catch a shared wrong result.
    assert (res1 == th.tensor([2, 2, 4])).all()
Example #2
0
def test_plan_can_be_jit_traced(hook, workers):
    """A stateful plan built with ``func2plan`` can also be traced by ``th.jit.trace``.

    The plan doubles its input and adds a bias read from its state. Running
    the plan directly and running the traced module must both yield the
    expected tensor.
    """
    shapes = [(1,)]
    initial_state = (th.tensor([1.0]),)

    @sy.func2plan(args_shape=shapes, state=initial_state)
    def foo(x, state):
        (bias,) = state.read()
        x = x * 2
        return x + bias

    # Decorating with args_shape builds the plan right away; it must be
    # printable and expose a non-empty readable representation.
    assert isinstance(foo.__str__(), str)
    assert len(foo.readable_plan) > 0
    assert foo.is_built

    sample = th.tensor([1.0, 2])
    expected = th.tensor([3.0, 5])

    direct_result = foo(sample)
    assert (direct_result == expected).all()

    # Trace the plan with fresh placeholders of the same shapes used to build it.
    trace_inputs = Plan._create_placeholders(shapes)
    traced = th.jit.trace(foo, trace_inputs)

    traced_result = traced(sample)
    assert (traced_result == expected).all()
Example #3
0
            if 0 < len(role.state.state_placeholders) == len(state) and isinstance(
                state, (list, tuple)
            ):
                state_placeholders = tuple(
                    role.placeholders[ph.id.value] for ph in role.state.state_placeholders
                )
                PlaceHolder.instantiate_placeholders(role.state.state_placeholders, state)
                PlaceHolder.instantiate_placeholders(state_placeholders, state)

            return plan(*args[:-1])

        plan_params = plan.parameters()
        if len(plan_params) > 0:
            torchscript_plan = jit.trace(wrap_stateful_plan, (*args, plan_params))
        else:
            torchscript_plan = jit.trace(plan, args)
        plan.torchscript = torchscript_plan
        plan.forward = tmp_forward

        return plan

    def remove(self):
        """Undo the TorchScript translation applied to ``self.plan``.

        Resets the plan's ``torchscript`` attribute to ``None`` and returns the
        same plan object (not a copy) to the caller.
        """
        translated_plan = self.plan
        translated_plan.torchscript = None
        return translated_plan


# Register translators that should apply at Plan build time.
# NOTE: this runs as an import-time side effect of this module, so importing it
# enables the TorchScript translator globally for Plans built afterwards
# (presumably each such Plan then gets a ``torchscript`` attribute; verify
# against Plan.register_build_translator).
Plan.register_build_translator(PlanTranslatorTorchscript)