Example #1
def _efficientnet(arch, pretrained=None, **kwargs):
    cfgs = deepcopy(CFGS)
    cfg_settings = cfgs[arch]["default"]
    cfg_params = cfg_settings.pop("params")
    cfg_params["blocks_args"] = decode_block_args(cfg_params["blocks_args"])
    if pretrained:
        pretrained_settings = cfgs[arch][pretrained]
        pretrained_params = pretrained_settings.pop("params", {})
        cfg_settings.update(pretrained_settings)
        cfg_params.update(pretrained_params)
    common_args = set(cfg_params.keys()).intersection(set(kwargs.keys()))
    if common_args:
        logging.warning(
            f"Args {common_args} are going to be overwritten by default params for {pretrained} weights"
        )
    kwargs.update(cfg_params)
    model = EfficientNet(**kwargs)
    if pretrained:
        state_dict = load_state_dict_from_url(cfgs[arch][pretrained]["url"])
        kwargs_cls = kwargs.get("num_classes", None)
        if kwargs_cls and kwargs_cls != cfg_settings["num_classes"]:
            logging.warning(
                "Using model pretrained for {} classes with {} classes. Last layer is initialized randomly".format(
                    cfg_settings["num_classes"], kwargs_cls
                )
            )
            state_dict["classifier.weight"] = model.state_dict()["classifier.weight"]
            state_dict["classifier.bias"] = model.state_dict()["classifier.bias"]
        if kwargs.get("in_channels", 3) != 3:  # support pretrained for custom input channels
            state_dict["conv_stem.weight"] = repeat_channels(
                state_dict["conv_stem.weight"], kwargs["in_channels"]
            )
        model.load_state_dict(state_dict)
    setattr(model, "pretrained_settings", cfg_settings)
    return model
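All of these builders read from a module-level CFGS registry that is not shown in the snippets. A minimal sketch of the layout the code above assumes; the architecture name, URL, and parameter values are placeholders, not the library's actual entries:

# hedged sketch of the CFGS layout the builders assume; all values are placeholders
CFGS = {
    "efficientnet_b0": {
        "default": {
            "num_classes": 1000,
            # encoded block strings, later expanded by decode_block_args
            "params": {"blocks_args": ["r1_k3_s11_e1_i32_o16_se0.25"]},
        },
        # one extra entry per set of pretrained weights
        "imagenet": {
            "url": "https://example.com/efficientnet_b0.pth",
            "params": {},
        },
    },
}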
Example #2
def _resnet(arch, pretrained=None, **kwargs):
    cfgs = deepcopy(CFGS)
    cfg_settings = cfgs[arch]["default"]
    cfg_params = cfg_settings.pop("params")
    if pretrained:
        pretrained_settings = cfgs[arch][pretrained]
        pretrained_params = pretrained_settings.pop("params", {})
        cfg_settings.update(pretrained_settings)
        cfg_params.update(pretrained_params)
    common_args = set(cfg_params.keys()).intersection(set(kwargs.keys()))
    if common_args:
        logging.warning(
            f"Args {common_args} are going to be overwritten by default params for {pretrained} weights"
        )
    kwargs.update(cfg_params)
    model = ResNet(**kwargs)
    if pretrained:
        state_dict = load_state_dict_from_url(cfgs[arch][pretrained]["url"])
        kwargs_cls = kwargs.get("num_classes", None)
        if kwargs_cls and kwargs_cls != cfg_settings["num_classes"]:
            logging.warning(
                "Using model pretrained for {} classes with {} classes. Last layer is initialized randomly".format(
                    cfg_settings["num_classes"], kwargs_cls
                )
            )
            # overwrite the checkpoint's fc head with the model's randomly initialized last_linear
            state_dict["fc.weight"] = model.state_dict()["last_linear.weight"]
            state_dict["fc.bias"] = model.state_dict()["last_linear.bias"]
        # support pretrained for custom input channels
        if kwargs.get("in_channels", 3) != 3:
            if "conv1.weight" in state_dict:
                name = "conv1.weight"
            elif "layer0.conv1.weight" in state_dict:  # fix for se_resne(x)t
                name = "layer0.conv1.weight"
            elif "conv1.1.weight" in state_dict:  # fix for BResNet
                name = "conv1.1.weight"
            else:
                raise KeyError(f"Don't know which first conv to patch for custom in_channels in {arch}")
            old_weights = state_dict[name]
            state_dict[name] = repeat_channels(
                old_weights,
                new_channels=int(kwargs["in_channels"] / 3 * old_weights.size(1)),
                old_channels=old_weights.size(1),
            )
        model.load_state_dict(state_dict)
        if cfg_settings.get("weight_standardization"):
            # implicitly convert convolutions to weight-standardized convs to
            # match how these checkpoints were trained
            model = conv_to_ws_conv(model)
    setattr(model, "pretrained_settings", cfg_settings)
    return model
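repeat_channels is also external to these snippets. Judging only from the call sites above (positional form in Example #1, explicit new_channels/old_channels here), a plausible minimal sketch, assuming the usual tile-and-rescale trick for adapting a pretrained first conv:

import torch

def repeat_channels(conv_weight, new_channels, old_channels=3):
    # Hypothetical implementation: tile the (out, in, kH, kW) kernel along the
    # input-channel dim, then rescale so activation magnitudes stay comparable.
    repeats = -(-new_channels // old_channels)  # ceil division
    weight = conv_weight.repeat(1, repeats, 1, 1)[:, :new_channels]
    return weight * old_channels / new_channels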
Example #3
def _hrnet(arch, pretrained=None, **kwargs):
    cfgs = deepcopy(CFGS)
    cfg_settings = cfgs[arch]["default"]
    cfg_params = cfg_settings.pop("params")
    if pretrained:
        pretrained_settings = cfgs[arch][pretrained]
        pretrained_params = pretrained_settings.pop("params", {})
        cfg_settings.update(pretrained_settings)
        cfg_params.update(pretrained_params)
    common_args = set(cfg_params.keys()).intersection(set(kwargs.keys()))
    assert (
        common_args == set()
    ), f"Args {common_args} are going to be overwritten by default params for {pretrained} weights"
    kwargs.update(cfg_params)
    model = HRNet(**kwargs)
    if pretrained:
        state_dict = load_state_dict_from_url(cfgs[arch][pretrained]["url"])
        kwargs_cls = kwargs.get("num_classes", None)
        if kwargs_cls and kwargs_cls != cfg_settings["num_classes"]:
            logging.warning(
                "Using model pretrained for {} classes with {} classes. Last layer is initialized randomly".format(
                    cfg_settings["num_classes"], kwargs_cls
                )
            )
            # overwrite the checkpoint's head(s) with the model's randomly initialized weights
            if cfg_params.get("OCR", False):
                state_dict["aux_head.2.weight"] = model.state_dict()["aux_head.2.weight"]
                state_dict["aux_head.2.bias"] = model.state_dict()["aux_head.2.bias"]
                state_dict["head.weight"] = model.state_dict()["head.weight"]
                state_dict["head.bias"] = model.state_dict()["head.bias"]
            else:
                state_dict["head.2.weight"] = model.state_dict()["head.2.weight"]
                state_dict["head.2.bias"] = model.state_dict()["head.2.bias"]
        # support custom number of input channels
        if kwargs.get("in_channels", 3) != 3:
            old_weights = state_dict.get("encoder.conv1.weight")
            state_dict["encoder.conv1.weight"] = repeat_channels(old_weights, kwargs["in_channels"])
        model.load_state_dict(state_dict)
        # the original models were trained with InplaceABN; adjust for it
        # (loading works without this patch, but results are slightly worse)
        patch_inplace_abn(model)
    setattr(model, "pretrained_settings", cfg_settings)
    return model
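The head-reset branches above all repeat one pattern: keep the checkpoint for the backbone, but copy the head tensors from the freshly constructed (randomly initialized) model so that load_state_dict succeeds with a different num_classes. A small generic helper in the same spirit (not part of the original code):

def reset_head_keys(state_dict, model, keys):
    # copy the model's randomly initialized head tensors into the checkpoint
    model_sd = model.state_dict()
    for key in keys:
        state_dict[key] = model_sd[key]
    return state_dict

# e.g. for the OCR variant above:
# reset_head_keys(state_dict, model,
#                 ["aux_head.2.weight", "aux_head.2.bias", "head.weight", "head.bias"])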
Example #4
def _resnet(arch, pretrained=None, **kwargs):
    cfgs = deepcopy(CFGS)
    cfg_settings = cfgs[arch]["default"]
    cfg_params = cfg_settings.pop("params")
    if pretrained:
        pretrained_settings = cfgs[arch][pretrained]
        pretrained_params = pretrained_settings.pop("params", {})
        cfg_settings.update(pretrained_settings)
        cfg_params.update(pretrained_params)
    common_args = set(cfg_params.keys()).intersection(set(kwargs.keys()))
    assert (
        common_args == set()
    ), f"Args {common_args} are going to be overwritten by default params for {pretrained} weights"
    kwargs.update(cfg_params)
    model = TResNet(**kwargs)
    if pretrained:
        state_dict = load_state_dict_from_url(cfgs[arch][pretrained]["url"], check_hash=True)
        kwargs_cls = kwargs.get("num_classes", None)
        if kwargs_cls and kwargs_cls != cfg_settings["num_classes"]:
            logging.warning(
                "Using model pretrained for {} classes with {} classes. Last layer is initialized randomly".format(
                    cfg_settings["num_classes"], kwargs_cls
                )
            )
            # overwrite the checkpoint's last_linear with the model's randomly initialized weights
            state_dict["last_linear.weight"] = model.state_dict()["last_linear.weight"]
            state_dict["last_linear.bias"] = model.state_dict()["last_linear.bias"]
        if kwargs.get("in_channels",
                      3) != 3:  # support pretrained for custom input channels
            state_dict["conv1.1.weight"] = repeat_channels(
                state_dict["conv1.1.weight"], kwargs["in_channels"] * 16,
                3 * 16)
        model.load_state_dict(state_dict)
        # adjust some parameters to align with the original model
        patch_blur_pool(model)
        patch_bn(model)
    setattr(model, "pretrained_settings", cfg_settings)
    return model
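The * 16 factors in the repeat_channels call come from TResNet's stem, which applies a stride-4 space-to-depth rearrangement before conv1.1, so a C-channel image reaches the first conv with C * 16 channels. A quick standalone check of that channel arithmetic:

import torch

x = torch.randn(1, 3, 224, 224)
s = 4  # space-to-depth block size used by the TResNet stem
b, c, h, w = x.shape
x = (
    x.view(b, c, h // s, s, w // s, s)
    .permute(0, 1, 3, 5, 2, 4)
    .reshape(b, c * s * s, h // s, w // s)
)
assert x.shape == (1, 48, 56, 56)  # 3 input channels -> 3 * 16 = 48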
Example #5
def _resnet(arch, pretrained=None, **kwargs):
    cfgs = deepcopy(CFGS)
    cfg_settings = cfgs[arch]["default"]
    cfg_params = cfg_settings.pop("params")
    if pretrained:
        pretrained_settings = cfgs[arch][pretrained]
        pretrained_params = pretrained_settings.pop("params", {})
        cfg_settings.update(pretrained_settings)
        cfg_params.update(pretrained_params)
    common_args = set(cfg_params.keys()).intersection(set(kwargs.keys()))
    if common_args:
        logging.warning(
            f"Args {common_args} are going to be overwritten by default params for {pretrained} weights"
        )
    kwargs.update(cfg_params)
    model = ResNet(**kwargs)
    if pretrained:
        state_dict = load_state_dict_from_url(cfgs[arch][pretrained]["url"])
        kwargs_cls = kwargs.get("num_classes", None)
        if kwargs_cls and kwargs_cls != cfg_settings["num_classes"]:
            logging.warning(
                "Using model pretrained for {} classes with {} classes. Last layer is initialized randomly".format(
                    cfg_settings["num_classes"], kwargs_cls
                )
            )
            # overwrite the checkpoint's fc head with the model's randomly initialized last_linear
            state_dict["fc.weight"] = model.state_dict()["last_linear.weight"]
            state_dict["fc.bias"] = model.state_dict()["last_linear.bias"]
        # support pretrained for custom input channels
        # (the layer0. prefix is needed for se_resne(x)t weights; write back
        # under whichever key the checkpoint actually uses)
        if kwargs.get("in_channels", 3) != 3:
            name = "conv1.weight" if "conv1.weight" in state_dict else "layer0.conv1.weight"
            state_dict[name] = repeat_channels(state_dict[name], kwargs["in_channels"])
        model.load_state_dict(state_dict)
    setattr(model, "pretrained_settings", cfg_settings)
    return model
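These private builders are presumably wrapped by thin public constructors; the names below are illustrative, not the library's actual API:

# hypothetical public wrapper around the private builder above
def resnet50(pretrained=None, **kwargs):
    return _resnet("resnet50", pretrained=pretrained, **kwargs)

# e.g. a grayscale 10-class model that reuses ImageNet weights where possible:
# the backbone comes from the checkpoint, the head stays randomly initialized,
# and the first conv is adapted to 1 input channel
model = resnet50(pretrained="imagenet", in_channels=1, num_classes=10)
print(model.pretrained_settings)  # whatever settings the registry carries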