def __init__(
    self,
    input_size=224,
    width_mult=1.,
    pretrained=True,
    pooling=None,
    pooling_kwargs=None,
):
    """Build an encoder from ``MobileNetV2`` with an optional pooling head.

    Args:
        input_size: forwarded to ``MobileNetV2`` as ``input_size``
        width_mult: forwarded to ``MobileNetV2`` as ``width_mult``
        pretrained: forwarded to ``MobileNetV2`` as ``pretrained``
        pooling (str): key looked up in ``MODULES``; ``None`` means
            no pooling layer is appended. If the key contains "attn"
            (case-insensitive), the layer is additionally constructed
            with ``in_features``
        pooling_kwargs (dict): extra keyword arguments for the pooling
            layer; ``None`` is treated as an empty dict
    """
    super().__init__()

    net = MobileNetV2(
        input_size=input_size,
        width_mult=width_mult,
        pretrained=pretrained,
    )
    # Collect the layers in a local list; ``self.encoder`` is assigned
    # only once, as the final nn.Sequential.
    layers = list(net.encoder.children())

    if pooling is not None:
        pooling_kwargs = pooling_kwargs or {}
        pooling_layer_fn = MODULES.get(pooling)
        if "attn" in pooling.lower():
            # Use the backbone's channel count. The previous code read
            # ``self.last_channel``, which is never assigned in this
            # method; ``net.output_channel`` is the same value that is
            # used for ``out_features`` below.
            pooling_layer = pooling_layer_fn(
                in_features=net.output_channel, **pooling_kwargs)
        else:
            pooling_layer = pooling_layer_fn(**pooling_kwargs)
        layers.append(pooling_layer)

        out_features = pooling_layer.out_features(
            in_features=net.output_channel)
    else:
        out_features = net.output_channel

    self.out_features = out_features
    # make it nn.Sequential
    self.encoder = nn.Sequential(*layers)

    # NOTE(review): this runs after the backbone has been built, also when
    # ``pretrained=True`` - confirm ``_initialize_weights`` (defined
    # outside this block) does not overwrite loaded weights.
    self._initialize_weights()
def __init__(self, arch="resnet34", pretrained=True, frozen=True,
             pooling=None, pooling_kwargs=None, cut_layers=2):
    """Build an encoder from a torchvision model with its tail removed.

    Args:
        arch (str): name looked up in ``torchvision.models``
        pretrained (bool): forwarded to the torchvision model
            constructor as ``pretrained``
        frozen (bool): if True, ``requires_grad`` is set to False for
            every parameter of the kept modules
        pooling (str): key looked up in ``MODULES``; when ``None``
            a ``Flatten`` layer is appended instead of a pooling layer
        pooling_kwargs (dict): extra keyword arguments for the pooling
            layer; a falsy value is treated as an empty dict
        cut_layers (int): number of trailing child modules of the
            model that are dropped (``children[:-cut_layers]``)
    """
    super().__init__()

    model_fn = torchvision.models.__dict__[arch]
    resnet = model_fn(pretrained=pretrained)

    # keep everything except the last ``cut_layers`` children
    layers = list(resnet.children())
    layers = layers[:-cut_layers]

    if frozen:
        for weight in (p for layer in layers for p in layer.parameters()):
            weight.requires_grad = False

    if pooling is None:
        out_features = resnet.fc.in_features
        layers.append(Flatten())
    else:
        if not pooling_kwargs:
            pooling_kwargs = {}
        pooling_layer_fn = MODULES.get(pooling)
        # only "attn" pooling layers are constructed with in_features
        if "attn" in pooling.lower():
            pooling_layer = pooling_layer_fn(
                in_features=resnet.fc.in_features, **pooling_kwargs)
        else:
            pooling_layer = pooling_layer_fn(**pooling_kwargs)
        layers.append(pooling_layer)
        out_features = pooling_layer.out_features(
            in_features=resnet.fc.in_features)

    self.out_features = out_features
    self.encoder = nn.Sequential(*layers)
def __init__(
    self,
    arch: str = "resnet18",
    pretrained: bool = True,
    frozen: bool = True,
    pooling: str = None,
    pooling_kwargs: dict = None,
    cut_layers: int = 2,
    state_dict: Union[dict, str, Path] = None,
):
    """
    Args:
        arch (str): Name for resnet. Have to be one of
            resnet18, resnet34, resnet50, resnet101, resnet152
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        frozen (bool): If frozen, sets requires_grad to False
        pooling (str): key of the pooling layer in ``MODULES``;
            if None, a ``Flatten`` layer is appended instead
        pooling_kwargs (dict): params for pooling
        cut_layers (int): number of trailing child modules of the
            resnet to drop (``children[:-cut_layers]``); note that
            ``0`` therefore yields an empty module list
        state_dict (Union[dict, str, Path]): Path to ``torch.Model``
            or a dict containing parameters and persistent buffers.
    """
    super().__init__()

    resnet = torchvision.models.__dict__[arch](pretrained=pretrained)
    if state_dict is not None:
        if isinstance(state_dict, (Path, str)):
            # Load onto CPU so that checkpoints saved from CUDA tensors
            # can be read on CPU-only hosts; ``load_state_dict`` copies
            # the values into the model afterwards.
            # NOTE: ``torch.load`` unpickles the file - only pass
            # paths to trusted checkpoints.
            state_dict = torch.load(str(state_dict), map_location="cpu")
        resnet.load_state_dict(state_dict)

    modules = list(resnet.children())[:-cut_layers]  # delete last layers

    if frozen:
        for module in modules:
            utils.set_requires_grad(module, requires_grad=False)

    if pooling is not None:
        pooling_kwargs = pooling_kwargs or {}
        pooling_layer_fn = MODULES.get(pooling)
        # only "attn" pooling layers are constructed with in_features
        if "attn" in pooling.lower():
            pooling_layer = pooling_layer_fn(
                in_features=resnet.fc.in_features, **pooling_kwargs)
        else:
            pooling_layer = pooling_layer_fn(**pooling_kwargs)
        modules.append(pooling_layer)

        # pooling layers without ``out_features`` leave the size unknown
        if hasattr(pooling_layer, "out_features"):
            out_features = pooling_layer.out_features(
                in_features=resnet.fc.in_features)
        else:
            out_features = None
    else:
        out_features = resnet.fc.in_features
        modules.append(Flatten())

    self.out_features = out_features
    self.encoder = nn.Sequential(*modules)