Example No. 1
from types import ModuleType

from torchvision import models

# load_model and _load_matched_weights are project-local helpers assumed to be in scope.


def _get_torchvision_model(name, num_classes, pretrained=True, checkpoint_path=None):
    model_constructor = getattr(models, name, None)
    if model_constructor is None or isinstance(model_constructor, ModuleType):
        # constructor doesn't exist or is a submodule instead of a function in torchvision
        raise ValueError("Torchvision model {} not found".format(name))
    # build model
    model = model_constructor(pretrained=False, num_classes=num_classes)
    if pretrained and not checkpoint_path:
        pretrained_model = model_constructor(pretrained=True, num_classes=1000)
        # reuse the ImageNet model directly when the class counts match; otherwise copy only the matching weights
        if num_classes == 1000:
            model = pretrained_model
        else:
            _load_matched_weights(model, pretrained_model)
        del pretrained_model

    if checkpoint_path is not None:
        load_model(checkpoint_path, model)
    return model
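
A minimal usage sketch for the helper above, assuming torchvision is installed and that the project-local load_model and _load_matched_weights helpers are importable; the model name and class count are placeholder values:

# Build a 10-class ResNet-50, copying over whichever ImageNet-pretrained
# weights still match ("resnet50" and 10 are example values).
model = _get_torchvision_model("resnet50", num_classes=10, pretrained=True)
model.eval()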
Example No. 2
        def wrapper(
            pretrained_path: str = None,
            pretrained: Union[bool, str] = False,
            pretrained_dataset: str = None,
            load_strict: bool = True,
            ignore_error_tensors: List[str] = None,
            *args,
            **kwargs,
        ):
            """
            :param pretrained_path: A path to the pretrained weights to load,
                if provided, will override the pretrained param
            :param pretrained: True to load the default pretrained weights,
                a string to load a specific pretrained weight
                (ex: base, optim, optim-perf),
                or False to not load any pretrained weights
            :param pretrained_dataset: The dataset to load pretrained weights for
                (ex: imagenet, mnist, etc).
                If not supplied, will default to the one preconfigured for the model.
            :param load_strict: True to raise an error on issues with state dict
                loading from pretrained_path or pretrained, False to ignore
            :param ignore_error_tensors: Tensors to ignore while checking the state dict
                for weights loaded from pretrained_path or pretrained
            """
            attributes = ModelRegistry._ATTRIBUTES[key]

            if attributes.args and pretrained in attributes.args:
                # attributes.args maps a pretrained flavor to a (kwarg name, value) pair
                arg_name, arg_value = attributes.args[pretrained]
                kwargs[arg_name] = arg_value

            model = const_func(*args, **kwargs)
            ignore = []

            if ignore_error_tensors:
                ignore.extend(ignore_error_tensors)
            elif attributes.ignore_error_tensors:
                ignore.extend(attributes.ignore_error_tensors)

            if isinstance(pretrained, str):
                if pretrained.lower() == "true":
                    pretrained = True
                elif pretrained.lower() in ["false", "none"]:
                    pretrained = False

            if pretrained_path:
                load_model(pretrained_path, model, load_strict, ignore)
            elif pretrained:
                zoo_model = ModelRegistry.create_zoo_model(
                    key, pretrained, pretrained_dataset)
                try:
                    paths = zoo_model.download_framework_files(
                        extensions=[".pth"])
                    load_model(paths[0], model, load_strict, ignore)
                except Exception:
                    # try one more time with overwrite on in case file was corrupted
                    paths = zoo_model.download_framework_files(
                        overwrite=True, extensions=[".pth"])
                    load_model(paths[0], model, load_strict, ignore)

            return model
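
A hedged usage sketch for the registry wrapper above, assuming it is the callable that a ModelRegistry.create-style entry point dispatches to for a registered key; the key, pretrained flavor, and dataset name below are illustrative, taken from the docstring examples rather than from the snippet itself:

# Load the "optim" pretrained flavor for an assumed registered key.
model = ModelRegistry.create(
    "resnet50", pretrained="optim", pretrained_dataset="imagenet"
)

# A local checkpoint path overrides the pretrained flag entirely; tensors named
# in ignore_error_tensors are skipped when the state dict is checked.
model = ModelRegistry.create(
    "resnet50",
    pretrained_path="/path/to/checkpoint.pth",  # placeholder path
    ignore_error_tensors=["classifier.fc.weight", "classifier.fc.bias"],  # placeholder names
)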
Example No. 3
    def wrapper(
        pretrained_path: str = None,
        pretrained: Union[bool, str] = False,
        pretrained_dataset: str = None,
        load_strict: bool = True,
        ignore_error_tensors: List[str] = None,
        **kwargs,
    ):
        """
        :param pretrained_path: A path to the pretrained weights to load,
            if provided, will override the pretrained param. May also be
            a SparseZoo stub path preceded by 'zoo:' with the optional
            `?recipe_type=` argument. If given a recipe type, the base
            model weights for that recipe will be loaded
        :param pretrained: True to load the default pretrained weights,
            a string to load a specific pretrained weight
            (ex: base, pruned-moderate),
            or False to not load any pretrained weights
        :param pretrained_dataset: The dataset to load pretrained weights for
            (ex: imagenet, mnist, etc).
            If not supplied, will default to the one preconfigured for the model.
        :param load_strict: True to raise an error on issues with state dict
            loading from pretrained_path or pretrained, False to ignore
        :param ignore_error_tensors: Tensors to ignore while checking the state dict
            for weights loaded from pretrained_path or pretrained
        """
        if isinstance(pretrained, str):
            if pretrained.lower() == "true":
                pretrained = True
            elif pretrained.lower() in ["false", "none"]:
                pretrained = False

        pretrained_torchvision = pretrained is True and not pretrained_path
        model = constructor_function(pretrained=pretrained_torchvision,
                                     **kwargs)
        ignore_error_tensors = ignore_error_tensors or []

        if pretrained_path:
            load_model(pretrained_path, model, load_strict,
                       ignore_error_tensors)
        elif pretrained and not pretrained_torchvision:
            zoo_model = ModelRegistry.create_zoo_model(key, pretrained,
                                                       pretrained_dataset)
            try:
                paths = zoo_model.download_framework_files(extensions=[".pth"])
                load_model(paths[0], model, load_strict, ignore_error_tensors)
            except Exception:
                # try one more time with overwrite on in case file was corrupted
                paths = zoo_model.download_framework_files(overwrite=True,
                                                           extensions=[".pth"])
                load_model(paths[0], model, load_strict, ignore_error_tensors)

        return model
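
Similarly, a sketch of how this torchvision-backed variant might be called, again assuming a ModelRegistry.create-style entry point and an illustrative key; pretrained=True defers to torchvision's bundled weights, while a string flavor is resolved through the zoo path:

# True falls through to torchvision's own pretrained weights.
model = ModelRegistry.create("mobilenet_v2", pretrained=True)

# A named flavor such as "pruned-moderate" (from the docstring example) is
# resolved via create_zoo_model and the downloaded .pth file instead.
model = ModelRegistry.create(
    "mobilenet_v2", pretrained="pruned-moderate", pretrained_dataset="imagenet"
)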