import logging
from copy import deepcopy

from torch.hub import load_state_dict_from_url

# CFGS, decode_block_args, repeat_channels, conv_to_ws_conv, the patch_* helpers and
# the model classes (EfficientNet, ResNet, HRNet, TResNet) are module-level objects
# from the surrounding package.


def _efficientnet(arch, pretrained=None, **kwargs):
    cfgs = deepcopy(CFGS)
    cfg_settings = cfgs[arch]["default"]
    cfg_params = cfg_settings.pop("params")
    cfg_params["blocks_args"] = decode_block_args(cfg_params["blocks_args"])
    if pretrained:
        pretrained_settings = cfgs[arch][pretrained]
        pretrained_params = pretrained_settings.pop("params", {})
        cfg_settings.update(pretrained_settings)
        cfg_params.update(pretrained_params)
    common_args = set(cfg_params.keys()).intersection(set(kwargs.keys()))
    if common_args:
        logging.warning(
            f"Args {common_args} are going to be overwritten by default params for {pretrained} weights"
        )
    kwargs.update(cfg_params)
    model = EfficientNet(**kwargs)
    if pretrained:
        state_dict = load_state_dict_from_url(cfgs[arch][pretrained]["url"])
        kwargs_cls = kwargs.get("num_classes", None)
        if kwargs_cls and kwargs_cls != cfg_settings["num_classes"]:
            logging.warning(
                "Using model pretrained for {} classes with {} classes. Last layer is initialized randomly".format(
                    cfg_settings["num_classes"], kwargs_cls
                )
            )
            # replace the classifier tensors in the loaded dict with the freshly
            # initialized ones so the strict load below succeeds
            state_dict["classifier.weight"] = model.state_dict()["classifier.weight"]
            state_dict["classifier.bias"] = model.state_dict()["classifier.bias"]
        if kwargs.get("in_channels", 3) != 3:
            # support pretrained weights for a custom number of input channels
            state_dict["conv_stem.weight"] = repeat_channels(
                state_dict["conv_stem.weight"], kwargs["in_channels"]
            )
        model.load_state_dict(state_dict)
    setattr(model, "pretrained_settings", cfg_settings)
    return model
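# The classifier swap above is a general trick: when num_classes differs from the
# pretrained head, copy the new model's randomly initialized head tensors into the
# downloaded state_dict so a strict load still succeeds. A minimal self-contained
# sketch (the Toy module and its attribute names are illustrative, not from this repo):
import torch.nn as nn


class Toy(nn.Module):
    def __init__(self, num_classes):
        super().__init__()
        self.body = nn.Linear(8, 8)
        self.classifier = nn.Linear(8, num_classes)


pretrained_sd = Toy(num_classes=1000).state_dict()  # stands in for downloaded weights
model = Toy(num_classes=10)
# replace the head entries so every tensor shape matches the new model
pretrained_sd["classifier.weight"] = model.state_dict()["classifier.weight"]
pretrained_sd["classifier.bias"] = model.state_dict()["classifier.bias"]
model.load_state_dict(pretrained_sd)  # strict load now succeeds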
def _resnet(arch, pretrained=None, **kwargs):
    cfgs = deepcopy(CFGS)
    cfg_settings = cfgs[arch]["default"]
    cfg_params = cfg_settings.pop("params")
    if pretrained:
        pretrained_settings = cfgs[arch][pretrained]
        pretrained_params = pretrained_settings.pop("params", {})
        cfg_settings.update(pretrained_settings)
        cfg_params.update(pretrained_params)
    common_args = set(cfg_params.keys()).intersection(set(kwargs.keys()))
    if common_args:
        logging.warning(
            f"Args {common_args} are going to be overwritten by default params for {pretrained} weights"
        )
    kwargs.update(cfg_params)
    model = ResNet(**kwargs)
    if pretrained:
        state_dict = load_state_dict_from_url(cfgs[arch][pretrained]["url"])
        kwargs_cls = kwargs.get("num_classes", None)
        if kwargs_cls and kwargs_cls != cfg_settings["num_classes"]:
            logging.warning(
                "Using model pretrained for {} classes with {} classes. Last layer is initialized randomly".format(
                    cfg_settings["num_classes"], kwargs_cls
                )
            )
            # overwrite the pretrained fc tensors with the randomly initialized
            # last_linear of the new model so shapes match on load
            state_dict["fc.weight"] = model.state_dict()["last_linear.weight"]
            state_dict["fc.bias"] = model.state_dict()["last_linear.bias"]
        # support pretrained weights for a custom number of input channels
        if kwargs.get("in_channels", 3) != 3:
            if "conv1.weight" in state_dict:
                name = "conv1.weight"
            elif "layer0.conv1.weight" in state_dict:  # fix for se_resne(x)t
                name = "layer0.conv1.weight"
            elif "conv1.1.weight" in state_dict:  # fix for BResNet
                name = "conv1.1.weight"
            else:  # fail loudly instead of hitting a NameError below
                raise KeyError("Can't find first conv weights to adjust for custom in_channels")
            old_weights = state_dict[name]
            state_dict[name] = repeat_channels(
                old_weights,
                new_channels=int(kwargs["in_channels"] / 3 * old_weights.size(1)),
                old_channels=old_weights.size(1),
            )
        model.load_state_dict(state_dict)
    if cfg_settings.get("weight_standardization"):
        # convert convs to weight-standardized convs implicitly;
        # maybe worth a logging warning here
        model = conv_to_ws_conv(model)
    setattr(model, "pretrained_settings", cfg_settings)
    return model
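# repeat_channels is a repo helper; a plausible minimal sketch of what it does,
# assuming it tiles the first-conv weights along the input-channel dim and rescales
# them so activation magnitudes stay roughly unchanged (a sketch only, not the
# repo's actual implementation):
import math

import torch


def repeat_channels_sketch(conv_weight, new_channels, old_channels=3):
    # conv_weight: (out_ch, old_channels, kH, kW)
    repeats = math.ceil(new_channels / old_channels)
    tiled = conv_weight.repeat(1, repeats, 1, 1)[:, :new_channels]
    # rescale so the per-output-channel sum stays comparable to the original
    return tiled * old_channels / new_channels


w = torch.randn(64, 3, 7, 7)
print(repeat_channels_sketch(w, new_channels=1).shape)  # torch.Size([64, 1, 7, 7])
print(repeat_channels_sketch(w, new_channels=4).shape)  # torch.Size([64, 4, 7, 7])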
def _hrnet(arch, pretrained=None, **kwargs):
    cfgs = deepcopy(CFGS)
    cfg_settings = cfgs[arch]["default"]
    cfg_params = cfg_settings.pop("params")
    if pretrained:
        pretrained_settings = cfgs[arch][pretrained]
        pretrained_params = pretrained_settings.pop("params", {})
        cfg_settings.update(pretrained_settings)
        cfg_params.update(pretrained_params)
    common_args = set(cfg_params.keys()).intersection(set(kwargs.keys()))
    assert common_args == set(), "Args {} are going to be overwritten by default params for {} weights".format(
        common_args, pretrained
    )
    kwargs.update(cfg_params)
    model = HRNet(**kwargs)
    if pretrained:
        state_dict = load_state_dict_from_url(cfgs[arch][pretrained]["url"])
        kwargs_cls = kwargs.get("num_classes", None)
        if kwargs_cls and kwargs_cls != cfg_settings["num_classes"]:
            logging.warning(
                "Using model pretrained for {} classes with {} classes. Last layer is initialized randomly".format(
                    cfg_settings["num_classes"], kwargs_cls
                )
            )
            # overwrite the pretrained head tensors with the randomly initialized
            # ones of the new model so shapes match on load
            if cfg_params.get("OCR", False):
                state_dict["aux_head.2.weight"] = model.state_dict()["aux_head.2.weight"]
                state_dict["aux_head.2.bias"] = model.state_dict()["aux_head.2.bias"]
                state_dict["head.weight"] = model.state_dict()["head.weight"]
                state_dict["head.bias"] = model.state_dict()["head.bias"]
            else:
                state_dict["head.2.weight"] = model.state_dict()["head.2.weight"]
                state_dict["head.2.bias"] = model.state_dict()["head.2.bias"]
        # support a custom number of input channels
        if kwargs.get("in_channels", 3) != 3:
            old_weights = state_dict.get("encoder.conv1.weight")
            state_dict["encoder.conv1.weight"] = repeat_channels(old_weights, kwargs["in_channels"])
        model.load_state_dict(state_dict)
    # the weights were trained with inplace ABN, so the model needs to be patched
    # to match; it works without this patch, but results are slightly worse
    patch_inplace_abn(model)
    setattr(model, "pretrained_settings", cfg_settings)
    return model
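# Hedged usage sketch: builders like _hrnet are usually exposed through thin public
# constructors keyed by the arch name. The "hrnet_w18" name and the "imagenet"
# weight tag below are illustrative assumptions, not read from this repo's CFGS:
def hrnet_w18(pretrained="imagenet", **kwargs):
    return _hrnet("hrnet_w18", pretrained=pretrained, **kwargs)


# model = hrnet_w18(in_channels=1)  # grayscale input; stem weights get adapted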
def _resnet(arch, pretrained=None, **kwargs):
    # TResNet builder; it keeps the _resnet name from its source module
    cfgs = deepcopy(CFGS)
    cfg_settings = cfgs[arch]["default"]
    cfg_params = cfg_settings.pop("params")
    if pretrained:
        pretrained_settings = cfgs[arch][pretrained]
        pretrained_params = pretrained_settings.pop("params", {})
        cfg_settings.update(pretrained_settings)
        cfg_params.update(pretrained_params)
    common_args = set(cfg_params.keys()).intersection(set(kwargs.keys()))
    assert common_args == set(), "Args {} are going to be overwritten by default params for {} weights".format(
        common_args, pretrained
    )
    kwargs.update(cfg_params)
    model = TResNet(**kwargs)
    if pretrained:
        state_dict = load_state_dict_from_url(cfgs[arch][pretrained]["url"], check_hash=True)
        kwargs_cls = kwargs.get("num_classes", None)
        if kwargs_cls and kwargs_cls != cfg_settings["num_classes"]:
            logging.warning(
                "Using model pretrained for {} classes with {} classes. Last layer is initialized randomly".format(
                    cfg_settings["num_classes"], kwargs_cls
                )
            )
            # overwrite the pretrained last_linear tensors with the randomly
            # initialized ones of the new model so shapes match on load
            state_dict["last_linear.weight"] = model.state_dict()["last_linear.weight"]
            state_dict["last_linear.bias"] = model.state_dict()["last_linear.bias"]
        if kwargs.get("in_channels", 3) != 3:
            # support pretrained weights for a custom number of input channels;
            # the stem conv sees in_channels * 16 inputs because of the
            # space-to-depth stem (4x4 pixel blocks folded into channels)
            state_dict["conv1.1.weight"] = repeat_channels(
                state_dict["conv1.1.weight"], kwargs["in_channels"] * 16, 3 * 16
            )
        model.load_state_dict(state_dict)
    # adjust some parameters to align with the original model
    patch_blur_pool(model)
    patch_bn(model)
    setattr(model, "pretrained_settings", cfg_settings)
    return model
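# Why conv1.1 expects in_channels * 16 inputs: TResNet's stem folds each 4x4 pixel
# block into channels (space-to-depth), so 3 RGB channels become 3 * 16 = 48 before
# the first conv. A shape check using torch's equivalent op, assuming the stem
# matches pixel_unshuffle with factor 4:
import torch
import torch.nn.functional as F

x = torch.randn(1, 3, 224, 224)
x = F.pixel_unshuffle(x, downscale_factor=4)
print(x.shape)  # torch.Size([1, 48, 56, 56]) -> 3 * 16 channels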
def _resnet(arch, pretrained=None, **kwargs):
    cfgs = deepcopy(CFGS)
    cfg_settings = cfgs[arch]["default"]
    cfg_params = cfg_settings.pop("params")
    if pretrained:
        pretrained_settings = cfgs[arch][pretrained]
        pretrained_params = pretrained_settings.pop("params", {})
        cfg_settings.update(pretrained_settings)
        cfg_params.update(pretrained_params)
    common_args = set(cfg_params.keys()).intersection(set(kwargs.keys()))
    if common_args:
        logging.warning(
            f"Args {common_args} are going to be overwritten by default params for {pretrained} weights"
        )
    kwargs.update(cfg_params)
    model = ResNet(**kwargs)
    if pretrained:
        state_dict = load_state_dict_from_url(cfgs[arch][pretrained]["url"])
        kwargs_cls = kwargs.get("num_classes", None)
        if kwargs_cls and kwargs_cls != cfg_settings["num_classes"]:
            logging.warning(
                "Using model pretrained for {} classes with {} classes. Last layer is initialized randomly".format(
                    cfg_settings["num_classes"], kwargs_cls
                )
            )
            # overwrite the pretrained fc tensors with the randomly initialized
            # last_linear of the new model so shapes match on load
            state_dict["fc.weight"] = model.state_dict()["last_linear.weight"]
            state_dict["fc.bias"] = model.state_dict()["last_linear.bias"]
        # support pretrained weights for a custom number of input channels;
        # the layer0. fallback is needed to support se_resne(x)t weights.
        # write back under the same key we read from, so plain resnets don't
        # end up with an unexpected layer0.conv1.weight entry
        if kwargs.get("in_channels", 3) != 3:
            name = "conv1.weight" if "conv1.weight" in state_dict else "layer0.conv1.weight"
            state_dict[name] = repeat_channels(state_dict[name], kwargs["in_channels"])
        model.load_state_dict(state_dict)
    setattr(model, "pretrained_settings", cfg_settings)
    return model
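# The pretrained_settings attribute attached above typically carries the metadata
# needed to reproduce training-time preprocessing. A hedged usage sketch; the
# "mean"/"std" keys and the arch/weight names are assumptions, not verified against
# this repo's CFGS:
# model = _resnet("resnet50", pretrained="imagenet")
# mean, std = model.pretrained_settings["mean"], model.pretrained_settings["std"]
# normalize = torchvision.transforms.Normalize(mean=mean, std=std)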