def test_double_stash_pop_but_isolated():
    """Reusing one skip name passes verification when each pair is isolated.

    Two stash/pop pairs both use the name 'foo'. Each pair is bound to its own
    ``Namespace`` via ``isolate``, and ``verify_skippables`` is expected to
    accept the model without raising.
    """

    @skippable(stash=["foo"])
    class FirstStash(nn.Module):
        """Stashes 'foo'."""

    @skippable(pop=["foo"])
    class FirstPop(nn.Module):
        """Pops 'foo'."""

    @skippable(stash=["foo"])
    class SecondStash(nn.Module):
        """Stashes 'foo' a second time."""

    @skippable(pop=["foo"])
    class SecondPop(nn.Module):
        """Pops 'foo' a second time."""

    first_ns, second_ns = Namespace(), Namespace()

    # The first pair shares first_ns and the second pair shares second_ns,
    # so the duplicated name 'foo' never collides.
    model = nn.Sequential(
        FirstStash().isolate(first_ns),
        FirstPop().isolate(first_ns),
        SecondStash().isolate(second_ns),
        SecondPop().isolate(second_ns),
    )
    verify_skippables(model)
def test_stash_not_pop():
    """A name that is stashed but never popped makes verification fail."""

    @skippable(stash=["foo"])
    class StashOnly(nn.Module):
        """Stashes 'foo'; no layer in the model pops it."""

    with pytest.raises(TypeError) as exc_info:
        model = nn.Sequential(StashOnly())
        verify_skippables(model)

    assert "no module declared 'foo' as poppable but stashed" in str(exc_info.value)
def test_pop_unknown():
    """Popping a name that no earlier layer stashed makes verification fail."""

    @skippable(pop=["foo"])
    class PopOnly(nn.Module):
        """Pops 'foo' although nothing stashed it."""

    with pytest.raises(TypeError) as exc_info:
        model = nn.Sequential(PopOnly())
        verify_skippables(model)

    # The message identifies the offending child as '0'.
    assert "'0' declared 'foo' as poppable but it was not stashed" in str(exc_info.value)
def test_stash_pop_together_same_name():
    """A single layer may not both stash and pop the same name."""

    @skippable(stash=["foo"], pop=["foo"])
    class StashAndPop(nn.Module):
        """Declares 'foo' as both stashable and poppable."""

    with pytest.raises(TypeError) as exc_info:
        model = nn.Sequential(StashAndPop())
        verify_skippables(model)

    assert "'0' declared 'foo' both as stashable and as poppable" in str(exc_info.value)
def test_matching():
    """One stash followed by a matching pop passes verification."""

    @skippable(stash=["foo"])
    class Producer(nn.Module):
        """Stashes 'foo'."""

    @skippable(pop=["foo"])
    class Consumer(nn.Module):
        """Pops 'foo'."""

    model = nn.Sequential(Producer(), Consumer())
    # Expect no exception.
    verify_skippables(model)
def test_stash_pop_together_different_names():
    """A layer may pop one name and stash a different one."""

    @skippable(stash=["foo"])
    class Head(nn.Module):
        """Stashes 'foo'."""

    @skippable(pop=["foo"], stash=["bar"])
    class Relay(nn.Module):
        """Pops 'foo' and stashes 'bar'."""

    @skippable(pop=["bar"])
    class Tail(nn.Module):
        """Pops 'bar'."""

    model = nn.Sequential(Head(), Relay(), Tail())
    # Expect no exception: both 'foo' and 'bar' are stashed once and popped once.
    verify_skippables(model)
def test_pop_again():
    """Popping the same stashed name twice makes verification fail."""

    @skippable(stash=["foo"])
    class Producer(nn.Module):
        """Stashes 'foo'."""

    @skippable(pop=["foo"])
    class FirstConsumer(nn.Module):
        """Pops 'foo'."""

    @skippable(pop=["foo"])
    class SecondConsumer(nn.Module):
        """Pops 'foo' again without a fresh stash."""

    with pytest.raises(TypeError) as exc_info:
        model = nn.Sequential(Producer(), FirstConsumer(), SecondConsumer())
        verify_skippables(model)

    # The error names the third child ('2') as the duplicate popper.
    assert "'2' redeclared 'foo' as poppable" in str(exc_info.value)