def test_is_pool_module(pool_modules, conv_modules):
    """Check that ``meta.is_pool_module`` separates pool from conv modules.

    Every module in ``pool_modules`` must be recognized as a pool module, and
    no module in ``conv_modules`` may be recognized as one.
    """
    for pool_module in pool_modules:
        type_name = pool_module.__class__.__name__
        failure = (
            f"{type_name} is a pool module, but it is not "
            f"recognized as one."
        )
        assert meta.is_pool_module(pool_module), failure

    for conv_module in conv_modules:
        type_name = conv_module.__class__.__name__
        failure = (
            f"{type_name} is not a pool module, but it is "
            f"recognized as one."
        )
        assert not meta.is_pool_module(conv_module), failure
def test_is_pool_module(self):
    """Verify ``meta.is_pool_module`` on the default pool and conv modules.

    Modules yielded by ``default_pool_modules()`` must be classified as pool
    modules; modules yielded by ``default_conv_modules()`` must not be.
    """
    for candidate in default_pool_modules():
        type_name = candidate.__class__.__name__
        failure = (
            f"{type_name} is a pool module, but it is not "
            f"recognized as one."
        )
        self.assertTrue(meta.is_pool_module(candidate), failure)

    for candidate in default_conv_modules():
        type_name = candidate.__class__.__name__
        failure = (
            f"{type_name} is not a pool module, but it is "
            f"recognized as one."
        )
        self.assertFalse(meta.is_pool_module(candidate), failure)
def test_multi_layer_encoder_avg_pool(mocker):
    """Assert that every pool module of the encoder is an ``nn.AvgPool2d``.

    The encoder is built with ``paper.multi_layer_encoder(impl_params=False)``.
    """
    # Patch state-dict loading from URLs; presumably this avoids real weight
    # downloads during the test (NOTE(review): confirm in the mocks module).
    mocks.patch_models_load_state_dict_from_url(mocker=mocker)
    encoder = paper.multi_layer_encoder(impl_params=False)

    # Materialize the pool modules first so that every module is classified
    # before the isinstance checks run.
    pooling_layers = list(filter(meta.is_pool_module, encoder.modules()))
    assert all(isinstance(layer, nn.AvgPool2d) for layer in pooling_layers)
def propagate_guide(
    module: nn.Module,
    guide: torch.Tensor,
    method: str = "simple",
    allow_empty: bool = False,
) -> torch.Tensor:
    """Propagate a guide through a single module.

    Conv modules are handled by ``_conv_guide`` (which receives ``method``) and
    pool modules by ``_pool_guide``. For any other module the guide is passed
    through unchanged.

    Args:
        module: Module the guide is propagated through.
        guide: Guide tensor to propagate.
        method: Propagation method forwarded to ``_conv_guide``. Must be one of
            ``"simple"``, ``"inside"``, or ``"all"`` (checked with
            ``verify_str_arg``). Defaults to ``"simple"``.
        allow_empty: If ``True``, return the propagated guide even if none of
            its entries are non-zero. Defaults to ``False``.

    Returns:
        The propagated guide.

    Raises:
        RuntimeError: If the propagated guide has no non-zero entries and
            ``allow_empty`` is ``False``.
    """
    verify_str_arg(method, "method", ("simple", "inside", "all"))

    if is_conv_module(module):
        propagated = _conv_guide(cast(ConvModule, module), guide, method)
    elif is_pool_module(module):
        propagated = _pool_guide(cast(PoolModule, module), guide)
    else:
        propagated = guide

    # ``allow_empty`` is checked first so that the emptiness check is skipped
    # entirely when empty guides are allowed. ``.bool()`` maps non-zero entries
    # to True, so ``torch.any`` is False only for an all-zero guide.
    if not allow_empty and not torch.any(propagated.bool()):
        msg = (
            f"Guide has no longer any entries after propagation through "
            f"{module.__class__.__name__}({module.extra_repr()}). If this is valid, "
            f"set allow_empty=True.")
        raise RuntimeError(msg)

    return propagated