Пример #1
0
def test_double_stash_pop_but_isolated():
    """Reusing one skip name is accepted when each stash/pop pair is isolated.

    Two independent stash/pop pairs both use the name 'foo'. Each pair is
    isolated into its own Namespace, and verify_skippables is expected to
    accept the resulting model without raising.
    """
    @skippable(stash=['foo'])
    class Layer1(nn.Module):
        pass

    @skippable(pop=['foo'])
    class Layer2(nn.Module):
        pass

    @skippable(stash=['foo'])
    class Layer3(nn.Module):
        pass

    @skippable(pop=['foo'])
    class Layer4(nn.Module):
        pass

    first_ns = Namespace()
    second_ns = Namespace()

    # (Layer1, Layer2) share first_ns; (Layer3, Layer4) share second_ns, so the
    # two 'foo' names should not collide.
    layers = [
        Layer1().isolate(first_ns),
        Layer2().isolate(first_ns),
        Layer3().isolate(second_ns),
        Layer4().isolate(second_ns),
    ]
    model = nn.Sequential(*layers)
    verify_skippables(model)
Пример #2
0
def test_stash_not_pop():
    """A name that is stashed but never popped must be rejected with TypeError."""
    @skippable(stash=['foo'])
    class Layer1(nn.Module):
        pass

    expected = "no module declared 'foo' as poppable but stashed"

    with pytest.raises(TypeError) as exc_info:
        model = nn.Sequential(Layer1())
        verify_skippables(model)
    # Checked after the `with` block: code after the raising call never runs.
    assert expected in str(exc_info.value)
Пример #3
0
def test_stash_pop_together_same_name():
    """A single module may not both stash and pop the same name."""
    @skippable(stash=['foo'], pop=['foo'])
    class Layer1(nn.Module):
        pass

    # '0' is the name under which nn.Sequential registers its only child.
    expected = "'0' declared 'foo' both as stashable and as poppable"

    with pytest.raises(TypeError) as exc_info:
        model = nn.Sequential(Layer1())
        verify_skippables(model)
    assert expected in str(exc_info.value)
Пример #4
0
def test_pop_unknown():
    """Popping a name that no earlier module stashed must raise TypeError."""
    @skippable(pop=['foo'])
    class Layer1(nn.Module):
        pass

    expected = "'0' declared 'foo' as poppable but it was not stashed"

    with pytest.raises(TypeError) as exc_info:
        model = nn.Sequential(Layer1())
        verify_skippables(model)
    assert expected in str(exc_info.value)
Пример #5
0
def test_matching():
    """One stash followed by a matching pop is expected to verify cleanly."""
    @skippable(stash=['foo'])
    class Layer1(nn.Module):
        pass

    @skippable(pop=['foo'])
    class Layer2(nn.Module):
        pass

    stasher, popper = Layer1(), Layer2()
    verify_skippables(nn.Sequential(stasher, popper))
Пример #6
0
def test_stash_pop_together_different_names():
    """A module may pop one name and stash a different one.

    The chain is 'foo' from Layer1 to Layer2, then 'bar' from Layer2 to
    Layer3. It is expected to pass verification without raising.
    """
    @skippable(stash=['foo'])
    class Layer1(nn.Module):
        pass

    @skippable(pop=['foo'], stash=['bar'])
    class Layer2(nn.Module):
        pass

    @skippable(pop=['bar'])
    class Layer3(nn.Module):
        pass

    chain = [Layer1(), Layer2(), Layer3()]
    verify_skippables(nn.Sequential(*chain))
Пример #7
0
def test_pop_again():
    """Popping the same stashed name twice must be rejected with TypeError."""
    @skippable(stash=['foo'])
    class Layer1(nn.Module):
        pass

    @skippable(pop=['foo'])
    class Layer2(nn.Module):
        pass

    @skippable(pop=['foo'])
    class Layer3(nn.Module):
        pass

    # '2' is the Sequential child name of Layer3, the second popper.
    expected = "'2' redeclared 'foo' as poppable"

    with pytest.raises(TypeError) as exc_info:
        model = nn.Sequential(Layer1(), Layer2(), Layer3())
        verify_skippables(model)
    assert expected in str(exc_info.value)