Exemplo n.º 1
0
def test_verify_module_duplicate_children():
    """verify_module must raise ValueError when one child appears twice."""
    shared = nn.Conv2d(3, 3, 1)
    # The very same module object is registered at two positions.
    duplicated = nn.Sequential(shared, shared)

    expected = 'module with duplicate children is not supported'
    with pytest.raises(ValueError, match=expected):
        verify_module(duplicated)
Exemplo n.º 2
0
def test_verify_module_duplicate_parameters_in_distinct_children():
    """verify_module must raise ValueError when distinct children share parameters."""

    class Wrapper(nn.Module):
        # Holds a reference to an existing module, so two different wrapper
        # instances end up exposing the same underlying parameters.
        def __init__(self, module):
            super().__init__()
            self.module = module

    shared = nn.Conv2d(3, 3, 1)
    # The two children are different objects, but both wrap ``shared``.
    model = nn.Sequential(Wrapper(shared), Wrapper(shared))

    expected = ('module with duplicate parameters in '
                'distinct children is not supported')
    with pytest.raises(ValueError, match=expected):
        verify_module(model)
Exemplo n.º 3
0
def test_verify_module_non_sequential():
    """verify_module must raise TypeError for a module that is not nn.Sequential."""
    plain = nn.Module()
    expected = 'module must be nn.Sequential to be partitioned'
    with pytest.raises(TypeError, match=expected):
        verify_module(plain)