Example #1
    def from_pretrained(cls, model_name, *args, **kwargs):
        model = super().from_pretrained(model_name, *args, **kwargs)
        config = load_pretrained_model(model_name)["full_config"]
        OmegaConf.set_struct(config, True)
        # Wrap the model in its inference interface for the hateful memes
        # checkpoint, or whenever the caller explicitly asks for one
        if model_name == "mmbt.hateful_memes.images" or kwargs.get("interface"):
            return MMBTGridHMInterface(model, config)
        return model
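For context, this override is typically invoked through the model class itself; a minimal usage sketch follows (the `classify` call reflects MMF's interface API for this checkpoint, but its exact signature is an assumption here):

    from mmf.models.mmbt import MMBT

    # The key matches the check above, so an MMBTGridHMInterface is returned
    model = MMBT.from_pretrained("mmbt.hateful_memes.images")
    # classify() is assumed to take an image (path or URL) and the meme text
    output = model.classify("meme.png", "some meme text")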
Example #2
    def from_pretrained(cls, model_name_or_path, *args, **kwargs):
        # If the path does not exist, treat it as a pretrained model key;
        # otherwise, load the checkpoint from the given path
        if not PathManager.exists(model_name_or_path):
            model_key = model_name_or_path.split(".")[0]
            model_cls = registry.get_model_class(model_key)
            assert model_cls == cls, (
                f"Incorrect pretrained model key {model_name_or_path} "
                f"for class {cls.__name__}"
            )
        output = load_pretrained_model(model_name_or_path, *args, **kwargs)
        config, checkpoint, full_config = (
            output["config"],
            output["checkpoint"],
            output["full_config"],
        )

        # Save original config for state reset later
        config_temp_holder = registry.get("config")
        # Register full config from checkpoint when loading the model
        registry.register("config", full_config)

        # Some models need registry updates to load a pretrained model.
        # If they have this method, call it so they can update accordingly
        if hasattr(cls, "update_registry_for_pretrained"):
            cls.update_registry_for_pretrained(config, checkpoint, output)

        instance = cls(config)
        instance.is_pretrained = True
        instance.build()
        incompatible_keys = instance.load_state_dict(checkpoint, strict=False)

        # The model has loaded; restore the registry to the original config
        registry.register("config", config_temp_holder)

        if len(incompatible_keys.missing_keys) != 0:
            logger.warning(
                f"Missing keys {incompatible_keys.missing_keys} in the"
                + " checkpoint.\n"
                + "If this is not your checkpoint, please open up an "
                + "issue on MMF GitHub. \n"
                + f"Unexpected keys if any: {incompatible_keys.unexpected_keys}"
            )

        if len(incompatible_keys.unexpected_keys) != 0:
            logger.warning(
                "Unexpected keys in state dict: "
                + f"{incompatible_keys.unexpected_keys} \n"
                + "This is usually not a problem with pretrained models, but "
                + "if this is your own model, please double check. \n"
                + "If you think this is an issue, please open up a "
                + "bug at MMF GitHub."
            )

        instance.eval()

        return instance
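The save/register/restore sequence around `registry` above is easy to get wrong if model construction raises partway through; a minimal sketch of the same idea as a context manager (a hypothetical helper, not part of MMF):

    from contextlib import contextmanager

    @contextmanager
    def swapped_config(registry, full_config):
        # Temporarily register the checkpoint's full config, restoring the
        # previous value even if model construction fails midway
        previous = registry.get("config")
        registry.register("config", full_config)
        try:
            yield
        finally:
            registry.register("config", previous)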
Example #3
    def from_pretrained(cls, model_name, *args, **kwargs):
        model = super().from_pretrained(model_name, *args, **kwargs)
        config = load_pretrained_model(model_name)["full_config"]
        OmegaConf.set_struct(config, True)

        if model_name == "late_fusion.hateful_memes" or kwargs.get("interface"):
            return GeneralInterface(model, config)
        return model
Example #4
    def _build_model(self):
        self.model_items = load_pretrained_model(self.checkpoint)
        self.config = OmegaConf.create(self.model_items["full_config"])
        # Pick the first (assumed only) dataset entry and build its processors
        dataset_name = list(self.config.dataset_config.keys())[0]
        processor = build_processors(
            self.config.dataset_config[dataset_name].processors)
        feature_extractor = build_encoder(
            self.model_items["config"].image_feature_encodings)
        # Build the model from the checkpoint's config and load its weights
        ckpt = self.model_items["checkpoint"]
        model = build_model(self.model_items["config"])
        model.load_state_dict(ckpt)

        return processor, feature_extractor, model
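A brief usage sketch of the returned triple; the calling context and the `eval()` call are assumptions, not part of the example above:

    processor, feature_extractor, model = self._build_model()
    model.eval()  # switch to inference mode before running predictions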
Example #5
File: base_model.py Project: ricklentz/mmf
    def from_pretrained(cls, model_name, *args, **kwargs):
        model_key = model_name.split(".")[0]
        model_cls = registry.get_model_class(model_key)
        assert (
            model_cls == cls
        ), f"Incorrect pretrained model key {model_name} for class {cls.__name__}"
        output = load_pretrained_model(model_name, *args, **kwargs)
        config, checkpoint = output["config"], output["checkpoint"]

        # Some models need registry updates to load a pretrained model.
        # If they have this method, call it so they can update accordingly
        if hasattr(cls, "update_registry_for_pretrained"):
            cls.update_registry_for_pretrained(config, checkpoint, output)

        instance = cls(config)
        instance.is_pretrained = True
        instance.build()
        incompatible_keys = instance.load_state_dict(checkpoint, strict=False)

        if len(incompatible_keys.missing_keys) != 0:
            logger.warning(
                f"Missing keys {incompatible_keys.missing_keys} in the"
                + " checkpoint.\n"
                + "If this is not your checkpoint, please open up an "
                + "issue on MMF GitHub. \n"
                + f"Unexpected keys if any: {incompatible_keys.unexpected_keys}"
            )

        if len(incompatible_keys.unexpected_keys) != 0:
            logger.warning(
                "Unexpected keys in state dict: "
                + f"{incompatible_keys.unexpected_keys} \n"
                + "This is usually not a problem with pretrained models, but "
                + "if this is your own model, please double check. \n"
                + "If you think this is an issue, please open up a "
                + "bug at MMF GitHub."
            )

        instance.eval()

        return instance
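The `hasattr` check above makes `update_registry_for_pretrained` an opt-in hook; a hypothetical subclass implementation might look like the following (the registry key and config attribute are illustrative only):

    @classmethod
    def update_registry_for_pretrained(cls, config, checkpoint, full_output):
        # Hypothetical: expose a value from the checkpoint's config so that
        # components constructed later inside cls(config) can look it up
        registry.register("pretrained_vocab_size", config.vocab_size)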
Example #6
File: base_model.py Project: Mokashaa/mmf
    def from_pretrained(cls, model_name, *args, **kwargs):
        model_key = model_name.split(".")[0]
        model_cls = registry.get_model_class(model_key)
        assert (
            model_cls == cls
        ), f"Incorrect pretrained model key {model_name} for class {cls.__name__}"
        output = load_pretrained_model(model_name, *args, **kwargs)
        config, checkpoint = output["config"], output["checkpoint"]

        # Some models need registry updates to load a pretrained model.
        # If they have this method, call it so they can update accordingly
        if hasattr(cls, "update_registry_for_pretrained"):
            cls.update_registry_for_pretrained(config, checkpoint, output)

        instance = cls(config)
        instance.is_pretrained = True
        instance.build()
        instance.load_state_dict(checkpoint)
        instance.eval()

        return instance
Example #7
    def from_pretrained(cls, model_name, *args, **kwargs):
        model = super().from_pretrained(model_name, *args, **kwargs)
        config = load_pretrained_model(model_name)["full_config"]
        OmegaConf.set_struct(config, True)
        return GeneralInterface(model, config)