Beispiel #1
0
def test_is_conv_module(conv_modules, pool_modules):
    """Check that ``meta.is_conv_module`` tells conv modules apart from pool modules.

    Every entry of ``conv_modules`` must be classified as a conv module, and
    no entry of ``pool_modules`` may be.
    """
    for conv in conv_modules:
        name = conv.__class__.__name__
        assert meta.is_conv_module(conv), (
            f"{name} is a conv module, but it is not recognized as one."
        )

    for pool in pool_modules:
        name = pool.__class__.__name__
        assert not meta.is_conv_module(pool), (
            f"{name} is not a conv module, but it is recognized as one."
        )
Beispiel #2
0
    def test_is_conv_module(self):
        """Conv modules must be detected as such; pool modules must not be."""
        for conv in default_conv_modules():
            name = conv.__class__.__name__
            failure = f"{name} is a conv module, but it is not recognized as one."
            self.assertTrue(meta.is_conv_module(conv), failure)

        for pool in default_pool_modules():
            name = pool.__class__.__name__
            failure = f"{name} is not a conv module, but it is recognized as one."
            self.assertFalse(meta.is_conv_module(pool), failure)
Beispiel #3
0
def propagate_guide(
    module: nn.Module,
    guide: torch.Tensor,
    method: str = "simple",
    allow_empty: bool = False,
) -> torch.Tensor:
    """Pass ``guide`` through ``module`` using the conv- or pool-specific rule.

    Args:
        module: Module to propagate through. Conv modules are handled by
            ``_conv_guide``, pool modules by ``_pool_guide``. Any other module
            leaves ``guide`` unchanged.
        guide: Guide tensor to propagate.
        method: Propagation method. Must be one of ``"simple"``, ``"inside"``
            or ``"all"``. It is validated for every call but is only forwarded
            to the conv rule.
        allow_empty: If ``True``, return the result even when it has no
            truthy entries.

    Returns:
        The propagated guide.

    Raises:
        RuntimeError: If the propagated guide has no truthy entries and
            ``allow_empty`` is ``False``.
    """
    verify_str_arg(method, "method", ("simple", "inside", "all"))

    propagated = guide
    if is_conv_module(module):
        propagated = _conv_guide(cast(ConvModule, module), propagated, method)
    elif is_pool_module(module):
        propagated = _pool_guide(cast(PoolModule, module), propagated)

    # Short-circuit: the emptiness check only runs when empty results are
    # not explicitly allowed.
    if not allow_empty and not torch.any(propagated.bool()):
        msg = (
            "Guide has no longer any entries after propagation through "
            f"{module.__class__.__name__}({module.extra_repr()}). "
            "If this is valid, set allow_empty=True."
        )
        raise RuntimeError(msg)

    return propagated