Ejemplo n.º 1
0
def test_double_stash_pop_but_isolated():
    """Stashing and popping the same skip name twice is fine when isolated.

    Layers 1/2 and layers 3/4 each form a stash/pop pair for the skip name
    "foo". Each pair is isolated in its own Namespace, so the duplicate
    name must not be reported. The test passes as long as
    ``verify_skippables`` returns without raising.
    """
    @skippable(stash=["foo"])
    class Layer1(nn.Module):
        pass

    @skippable(pop=["foo"])
    class Layer2(nn.Module):
        pass

    @skippable(stash=["foo"])
    class Layer3(nn.Module):
        pass

    @skippable(pop=["foo"])
    class Layer4(nn.Module):
        pass

    first_ns, second_ns = Namespace(), Namespace()

    # First stash/pop pair lives in first_ns, second pair in second_ns.
    layers = [
        Layer1().isolate(first_ns),
        Layer2().isolate(first_ns),
        Layer3().isolate(second_ns),
        Layer4().isolate(second_ns),
    ]
    verify_skippables(nn.Sequential(*layers))
Ejemplo n.º 2
0
def test_namespace():
    """Copy policy across namespaces is ordered by source partition index.

    p1 and p2 each stash the skip named "foo", in namespaces ns1 and ns2
    respectively. p3 pops both of them, but in the opposite order: ns2's
    first, then ns1's.
    """
    ns1 = Namespace()
    ns2 = Namespace()

    p1 = nn.Sequential(StashFoo().isolate(ns1))
    p2 = nn.Sequential(StashFoo().isolate(ns2))
    p3 = nn.Sequential(PopFoo().isolate(ns2), PopFoo().isolate(ns1))
    partitions = [p1, p2, p3]

    layout = inspect_skip_layout(partitions)
    policy = [list(layout.copy_policy(i)) for i in range(len(partitions))]

    # p3 pops ns2's 'foo' before ns1's 'foo', but the plan is sorted by
    # source partition index. Each entry appears to be
    # (source partition index, namespace, skip name).
    assert policy == [[], [], [(0, ns1, "foo"), (1, ns2, "foo")]]
Ejemplo n.º 3
0
def test_namespace_copy():
    """A shallow copy of a Namespace equals the original but is a new object."""
    original = Namespace()

    equal_copy = copy.copy(original)
    assert equal_copy == original

    distinct_copy = copy.copy(original)
    assert distinct_copy is not original
Ejemplo n.º 4
0
def test_namespace_difference():
    """Two separately created Namespace instances compare unequal."""
    first, second = Namespace(), Namespace()
    assert first != second