Example #1
0
def test_is_pool_module(pool_modules, conv_modules):
    """Check that ``meta.is_pool_module`` classifies modules correctly.

    Every module in ``pool_modules`` must be recognized as a pool module and
    no module in ``conv_modules`` may be.
    """
    for pool_module in pool_modules:
        cls_name = pool_module.__class__.__name__
        assert meta.is_pool_module(pool_module), (
            f"{cls_name} is a pool module, but it is not "
            "recognized as one."
        )

    for conv_module in conv_modules:
        cls_name = conv_module.__class__.__name__
        assert not meta.is_pool_module(conv_module), (
            f"{cls_name} is not a pool module, but it is "
            "recognized as one."
        )
Example #2
0
    def test_is_pool_module(self):
        """Check ``meta.is_pool_module`` against the default module sets.

        Each module yielded by ``default_pool_modules()`` must be recognized
        as a pool module; none yielded by ``default_conv_modules()`` may be.
        """
        for pool in default_pool_modules():
            cls_name = pool.__class__.__name__
            self.assertTrue(
                meta.is_pool_module(pool),
                f"{cls_name} is a pool module, but it is not "
                "recognized as one.",
            )

        for conv in default_conv_modules():
            cls_name = conv.__class__.__name__
            self.assertFalse(
                meta.is_pool_module(conv),
                f"{cls_name} is not a pool module, but it is "
                "recognized as one.",
            )
Example #3
0
def test_multi_layer_encoder_avg_pool(mocker):
    """Check that the encoder built with ``impl_params=False`` uses average pooling.

    Every module of the encoder that ``meta.is_pool_module`` recognizes must
    be an ``nn.AvgPool2d``, and at least one such module must exist.
    """
    # Patch state-dict loading via the project's ``mocks`` helper
    # (presumably so no pretrained weights are downloaded -- confirm).
    mocks.patch_models_load_state_dict_from_url(mocker=mocker)

    multi_layer_encoder = paper.multi_layer_encoder(impl_params=False)
    pool_modules = [
        module for module in multi_layer_encoder.modules()
        if meta.is_pool_module(module)
    ]

    # ``all`` over an empty list is vacuously True, so without this guard the
    # test would pass even if the encoder contained no pooling layers at all.
    assert pool_modules, "The multi-layer encoder contains no pool modules."

    # Equivalent to ``all(isinstance(m, nn.AvgPool2d) ...)``; listing the
    # offenders makes a failure self-explanatory.
    non_avg_pools = [
        module.__class__.__name__ for module in pool_modules
        if not isinstance(module, nn.AvgPool2d)
    ]
    assert not non_avg_pools, (
        f"Expected only AvgPool2d pool modules, but found: {non_avg_pools}"
    )
Example #4
0
def propagate_guide(
    module: nn.Module,
    guide: torch.Tensor,
    method: str = "simple",
    allow_empty: bool = False,
) -> torch.Tensor:
    """Propagate a guide tensor through a single module.

    Convolution modules are handled by ``_conv_guide`` (which receives
    ``method``), and pooling modules are handled by ``_pool_guide``. Any other
    module leaves the guide unchanged.

    Args:
        module: Module the guide is propagated through.
        guide: Guide tensor. An entry counts as present if it is non-zero
            (see ``guide.bool()``).
        method: Propagation method for convolution modules. It is validated
            by ``verify_str_arg`` against ``"simple"``, ``"inside"`` and
            ``"all"``, and is only used for convolution modules.
        allow_empty: If ``True``, an all-zero result is returned instead of
            raising.

    Returns:
        The propagated guide.

    Raises:
        RuntimeError: If the propagated guide has no non-zero entries and
            ``allow_empty`` is ``False``.
    """
    verify_str_arg(method, "method", ("simple", "inside", "all"))

    if is_conv_module(module):
        propagated = _conv_guide(cast(ConvModule, module), guide, method)
    elif is_pool_module(module):
        propagated = _pool_guide(cast(PoolModule, module), guide)
    else:
        # Neither conv nor pool: the module does not alter the guide.
        propagated = guide

    # Short-circuits like ``allow_empty or torch.any(...)``: the tensor check
    # only runs when empty guides are not allowed.
    if not allow_empty and not torch.any(propagated.bool()):
        module_name = module.__class__.__name__
        raise RuntimeError(
            f"Guide has no longer any entries after propagation through "
            f"{module_name}({module.extra_repr()}). If this is valid, "
            f"set allow_empty=True."
        )

    return propagated