def test_double_stash_pop_but_isolated():
    """Reusing a skip name is accepted when each stash/pop pair is isolated.

    Two stash/pop pairs share the skip name "foo". Each pair is isolated into
    its own Namespace, and ``verify_skippables`` is expected to accept the
    resulting sequential model.
    """

    @skippable(stash=["foo"])
    class Layer1(nn.Module):
        pass

    @skippable(pop=["foo"])
    class Layer2(nn.Module):
        pass

    @skippable(stash=["foo"])
    class Layer3(nn.Module):
        pass

    @skippable(pop=["foo"])
    class Layer4(nn.Module):
        pass

    # One namespace per stash/pop pair, so the two "foo" skips do not collide.
    first_ns = Namespace()
    second_ns = Namespace()

    model = nn.Sequential(
        Layer1().isolate(first_ns),
        Layer2().isolate(first_ns),
        Layer3().isolate(second_ns),
        Layer4().isolate(second_ns),
    )
    verify_skippables(model)
def test_namespace():
    """Copy policy for namespaced skips is ordered by source partition index.

    Partition 0 stashes "foo" in ``ns1`` and partition 1 stashes "foo" in
    ``ns2``. Partition 2 pops them in the opposite order (``ns2`` first, then
    ``ns1``). The copy policy for partition 2 must still list the ``ns1`` skip
    (source partition 0) before the ``ns2`` skip (source partition 1).
    """
    ns1 = Namespace()
    ns2 = Namespace()

    p1 = nn.Sequential(StashFoo().isolate(ns1))
    p2 = nn.Sequential(StashFoo().isolate(ns2))
    p3 = nn.Sequential(PopFoo().isolate(ns2), PopFoo().isolate(ns1))

    layout = inspect_skip_layout([p1, p2, p3])
    policy = [list(layout.copy_policy(i)) for i in range(3)]

    # p3 pops ns2's 'foo' before ns1's 'foo', but the plan is sorted by
    # source partition index (ns1's stash lives in partition 0, ns2's in 1).
    assert policy == [[], [], [(0, ns1, "foo"), (1, ns2, "foo")]]
def test_namespace_copy():
    """A shallow copy of a Namespace compares equal but is a distinct object."""
    original = Namespace()
    duplicate = copy.copy(original)

    assert duplicate == original
    assert duplicate is not original
def test_namespace_difference():
    """Two independently created Namespaces are not equal."""
    first, second = Namespace(), Namespace()
    assert first != second