def __init__(
    self,
    arch: str = "resnet18",
    pretrained: bool = True,
    frozen: bool = True,
    pooling: str = None,
    pooling_kwargs: dict = None,
    cut_layers: int = 2,
    state_dict: Union[dict, str, Path] = None,
):
    """Build an encoder out of a torchvision ResNet.

    Args:
        arch: Name for resnet. Have to be one of
            resnet18, resnet34, resnet50, resnet101, resnet152
        pretrained: If True, returns a model pre-trained on ImageNet
        frozen: If True, sets ``requires_grad`` to False for the kept
            resnet modules. A pooling layer appended afterwards
            is not frozen.
        pooling: key passed to ``MODULE.get`` to obtain the pooling layer
            factory. Names containing ``"attn"`` additionally receive
            ``in_features``. ``None`` means no pooling layer.
        pooling_kwargs: params for pooling
        cut_layers: number of trailing child modules of the resnet
            to drop. ``0`` keeps every child module.
        state_dict (Union[dict, str, Path]): Path to ``torch.Model``
            or a dict containing parameters and persistent buffers.
    """
    super().__init__()

    resnet = torchvision.models.__dict__[arch](pretrained=pretrained)
    if state_dict is not None:
        if isinstance(state_dict, (Path, str)):
            # NOTE: torch.load unpickles the file; only load checkpoints
            # that come from a trusted source.
            state_dict = torch.load(str(state_dict))
        resnet.load_state_dict(state_dict)

    children = list(resnet.children())
    # ``children[:-0]`` is an empty list, so "cut nothing" has to be
    # handled explicitly instead of through the negative slice.
    modules = children[:-cut_layers] if cut_layers != 0 else children

    if frozen:
        for module in modules:
            utils.set_requires_grad(module, requires_grad=False)

    if pooling is not None:
        pooling_kwargs = pooling_kwargs or {}
        pooling_layer_fn = MODULE.get(pooling)
        if "attn" in pooling.lower():
            pooling_layer = pooling_layer_fn(
                in_features=resnet.fc.in_features, **pooling_kwargs
            )
        else:
            pooling_layer = pooling_layer_fn(**pooling_kwargs)
        modules += [pooling_layer]

        if hasattr(pooling_layer, "out_features"):
            out_features = pooling_layer.out_features(
                in_features=resnet.fc.in_features
            )
        else:
            # the pooling layer does not report its output size
            out_features = None
    else:
        # NOTE(review): this equals the encoder output size only if the
        # final ``fc`` layer was cut off - confirm for cut_layers < 1.
        out_features = resnet.fc.in_features

    modules += [Flatten()]
    self.out_features = out_features

    self.encoder = nn.Sequential(*modules)
class EncoderDecoderNet(nn.Module):
    """Generalized Encoder-Decoder network.

    Args:
        encoder: Encoder module, usually used for the extraction
            of embeddings from input signals.
        decoder: Decoder module, usually used for embeddings processing
            e.g. generation of signal similar to the input one (in GANs).
    """

    def __init__(self, encoder: nn.Module, decoder: nn.Module) -> None:
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward pass method.

        Args:
            x: Batch of input signals e.g. images.

        Returns:
            Batch of generated signals e.g. images,
            clamped to the ``[0.0, 1.0]`` range.
        """
        x = self.encoder(x)
        x = self.decoder(x)
        x = torch.clamp(x, min=0.0, max=1.0)
        return x

    @classmethod
    def get_from_params(
        cls,
        encoder_params: Optional[dict] = None,
        decoder_params: Optional[dict] = None,
    ) -> "EncoderDecoderNet":
        """Create model based on its config.

        Each params dict must contain a ``"module"`` key that is passed to
        ``MODULE.get``; the remaining items are forwarded to the returned
        factory as keyword arguments. A part whose params are ``None``
        becomes ``nn.Identity``.

        Args:
            encoder_params: Encoder module params.
            decoder_params: Decoder module parameters.

        Returns:
            Model.
        """
        encoder: nn.Module = nn.Identity()
        # work on a deep copy so that popping "module" leaves
        # the caller's dict untouched
        if (encoder_params_ := copy.deepcopy(encoder_params)) is not None:
            encoder_fn = MODULE.get(encoder_params_.pop("module"))
            encoder = encoder_fn(**encoder_params_)

        decoder: nn.Module = nn.Identity()
        if (decoder_params_ := copy.deepcopy(decoder_params)) is not None:
            decoder_fn = MODULE.get(decoder_params_.pop("module"))
            decoder = decoder_fn(**decoder_params_)

        # the built parts have to be assembled and returned, otherwise
        # the method silently yields None
        return cls(encoder=encoder, decoder=decoder)
def _process_fn_params(params: ModuleParams, key: Optional[str] = None) -> Callable[..., nn.Module]: module_fn: Callable[..., nn.Module] if callable(params): module_fn = params elif isinstance(params, str): name = params module_fn = MODULE.get(name) elif isinstance(params, dict) and key is not None: params = copy.deepcopy(params) name_or_fn = params.pop(key) module_fn = _process_fn_params(name_or_fn) module_fn = functools.partial(module_fn, **params) else: NotImplementedError() return module_fn
def get_from_params(
    cls,
    encoder_params: Optional[dict] = None,
    decoder_params: Optional[dict] = None,
) -> "EncoderDecoderNet":
    """Create model based on its config.

    Each params dict must contain a ``"module"`` key that is passed to
    ``MODULE.get``; the remaining items are forwarded to the returned
    factory as keyword arguments. A part whose params are ``None``
    becomes ``nn.Identity``.

    Args:
        encoder_params: Encoder module params.
        decoder_params: Decoder module parameters.

    Returns:
        Model.
    """
    encoder: nn.Module = nn.Identity()
    # work on a deep copy so that popping "module" leaves
    # the caller's dict untouched
    if (encoder_params_ := copy.deepcopy(encoder_params)) is not None:
        encoder_fn = MODULE.get(encoder_params_.pop("module"))
        encoder = encoder_fn(**encoder_params_)

    # ``decoder_params`` is handled the same way as the encoder params
    decoder: nn.Module = nn.Identity()
    if (decoder_params_ := copy.deepcopy(decoder_params)) is not None:
        decoder_fn = MODULE.get(decoder_params_.pop("module"))
        decoder = decoder_fn(**decoder_params_)

    # assemble and return the model instead of falling through with None
    return cls(encoder=encoder, decoder=decoder)
def get_from_params(
    cls,
    encoder_params: Optional[dict] = None,
    pooling_params: Optional[dict] = None,
    head_params: Optional[dict] = None,
) -> "VGGConv":
    """Create model based on its config.

    Each params dict must contain a ``"module"`` key that is passed to
    ``MODULE.get``; the remaining items are forwarded to the returned
    factory as keyword arguments. A part whose params are ``None``
    becomes ``nn.Identity``.

    Args:
        encoder_params: Params of encoder module.
        pooling_params: Params of the pooling layer.
        head_params: 'Head' module params.

    Returns:
        Model.
    """
    encoder: nn.Module = nn.Identity()
    # work on a deep copy so that popping "module" leaves
    # the caller's dict untouched
    if (encoder_params_ := copy.deepcopy(encoder_params)) is not None:
        encoder_fn = MODULE.get(encoder_params_.pop("module"))
        encoder = encoder_fn(**encoder_params_)

    pool: nn.Module = nn.Identity()
    if (pooling_params_ := copy.deepcopy(pooling_params)) is not None:
        pool_fn = MODULE.get(pooling_params_.pop("module"))
        pool = pool_fn(**pooling_params_)

    head: nn.Module = nn.Identity()
    if (head_params_ := copy.deepcopy(head_params)) is not None:
        head_fn = MODULE.get(head_params_.pop("module"))
        head = head_fn(**head_params_)

    # assemble the model; without this the method would return None
    net = cls(encoder=encoder, pool=pool, head=head)
    # NOTE(review): presumably initializes the weights in place - confirm
    utils.net_init_(net)
    return net
class VGGConv(nn.Module):
    """VGG-like neural network for image classification.

    Args:
        encoder: Image encoder module, usually used for the extraction
            of embeddings from input signals.
        pool: Pooling layer, used to reduce embeddings from the encoder.
        head: Classification head, usually consists of Fully Connected layers.
    """

    def __init__(
        self,
        encoder: nn.Module,
        pool: nn.Module,
        head: nn.Module,
    ) -> None:
        super().__init__()
        self.encoder = encoder
        self.pool = pool
        self.head = head

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward call.

        Args:
            x: Batch of images.

        Returns:
            Batch of logits.
        """
        x = self.pool(self.encoder(x))
        # flatten everything except the leading (batch) dimension
        x = x.view(x.shape[0], -1)
        x = self.head(x)
        return x

    @classmethod
    def get_from_params(
        cls,
        encoder_params: Optional[dict] = None,
        pooling_params: Optional[dict] = None,
        head_params: Optional[dict] = None,
    ) -> "VGGConv":
        """Create model based on its config.

        Each params dict must contain a ``"module"`` key that is passed to
        ``MODULE.get``; the remaining items are forwarded to the returned
        factory as keyword arguments. A part whose params are ``None``
        becomes ``nn.Identity``.

        Args:
            encoder_params: Params of encoder module.
            pooling_params: Params of the pooling layer.
            head_params: 'Head' module params.

        Returns:
            Model.
        """
        encoder: nn.Module = nn.Identity()
        # work on a deep copy so that popping "module" leaves
        # the caller's dict untouched
        if (encoder_params_ := copy.deepcopy(encoder_params)) is not None:
            encoder_fn = MODULE.get(encoder_params_.pop("module"))
            encoder = encoder_fn(**encoder_params_)

        pool: nn.Module = nn.Identity()
        if (pooling_params_ := copy.deepcopy(pooling_params)) is not None:
            pool_fn = MODULE.get(pooling_params_.pop("module"))
            pool = pool_fn(**pooling_params_)

        head: nn.Module = nn.Identity()
        if (head_params_ := copy.deepcopy(head_params)) is not None:
            head_fn = MODULE.get(head_params_.pop("module"))
            head = head_fn(**head_params_)

        # assemble the model; without this the method would return None
        net = cls(encoder=encoder, pool=pool, head=head)
        # NOTE(review): presumably initializes the weights in place - confirm
        utils.net_init_(net)
        return net
encoder_params: Params of encoder module.
            pooling_params: Params of the pooling layer.
            head_params: 'Head' module params.

        Returns:
            Model.
        """
        # Every part defaults to a pass-through ``nn.Identity`` and is only
        # replaced when the corresponding params are not None.
        encoder: nn.Module = nn.Identity()
        # The params are deep-copied first, so popping "module" below does
        # not modify the dict passed in by the caller.
        if (encoder_params_ := copy.deepcopy(encoder_params)) is not None:
            # "module" is handed to MODULE.get (presumably a registry
            # lookup - confirm); the remaining items become keyword arguments.
            encoder_fn = MODULE.get(encoder_params_.pop("module"))
            encoder = encoder_fn(**encoder_params_)

        # Same scheme for the pooling layer.
        pool: nn.Module = nn.Identity()
        if (pooling_params_ := copy.deepcopy(pooling_params)) is not None:
            pool_fn = MODULE.get(pooling_params_.pop("module"))
            pool = pool_fn(**pooling_params_)

        # Same scheme for the head.
        head: nn.Module = nn.Identity()
        if (head_params_ := copy.deepcopy(head_params)) is not None:
            head_fn = MODULE.get(head_params_.pop("module"))
            head = head_fn(**head_params_)

        net = cls(encoder=encoder, pool=pool, head=head)
        # NOTE(review): presumably initializes the weights in place - confirm
        utils.net_init_(net)
        return net


# Public API of the module.
__all__ = ["VGGConv"]