예제 #1
0
파일: base.py 프로젝트: tc-wolf/finetune
    def load(path, *args, **kwargs):
        """
        Load a saved fine-tuned model from disk.  Path provided should be a folder which contains .pkl and tf.Saver() files

        :param path: string path name (or an object exposing ``write``) to load model from.  Same value as
            previously provided to :meth:`save`. Must be a folder.
        :param **kwargs: key-value pairs of config items to override.
        :raises FinetuneError: if the method is invoked on a model instance instead of on the class.
        :return: the loaded model, with config overrides applied, weights attached and marked as trained.
        """
        # When invoked as `instance.load(...)`, `path` is bound to the instance and the
        # real path (if one was given) is shifted into `args`.
        if not isinstance(path, str) and not hasattr(path, "write"):
            instance = path
            # Guard against `instance.load()` with no arguments, which would otherwise
            # raise IndexError here and hide the intended error.
            suggested_path = args[0] if args else "<path>"
            raise FinetuneError(
                'The .load() method can only be called on the class, not on an instance. Try `{}.load("{}")` instead.'
                .format(instance.__class__.__name__, suggested_path))

        assert_valid_config(**kwargs)

        saver = Saver()
        model = saver.load(path)

        # Backwards compatibility:
        # ensure models saved before a setting existed pick up a value for it.
        for setting, default in get_default_config().items():
            if not hasattr(model.config, setting):
                if setting == "add_eos_bos_to_chunk":
                    # Old models are pinned to False rather than the current default
                    # (presumably to keep their original chunking behaviour — confirm).
                    model.config.add_eos_bos_to_chunk = False
                else:
                    model.config.update({setting: default})

        # User overrides win over both saved and defaulted settings.
        model.config.update(kwargs)
        model.input_pipeline.config = model.config
        download_data_if_required(model.config.base_model)
        saver.set_fallback(model.config.base_model_path)
        model._initialize()
        model.saver.variables = saver.variables
        model._trained = True
        return model
예제 #2
0
    def load(cls, path):
        """
        Restore a fine-tuned model that was previously written with :meth:`save`.

        :param path: folder path (string) holding the .pkl file and tf.Saver() checkpoint files.
        :return: the restored model, initialized and with the loaded variables attached to its saver.
        """
        checkpoint_saver = Saver(JL_BASE)
        restored = checkpoint_saver.load(path)
        restored._initialize()
        # Hand the loaded variable values over to the model's own saver.
        restored.saver.variables = checkpoint_saver.variables
        return restored
예제 #3
0
파일: base.py 프로젝트: bin2000/finetune
    def load(cls, path):
        """
        Restore a fine-tuned model that was previously written with :meth:`save`.

        :param path: folder path (string) holding the .pkl file and tf.Saver() checkpoint files.
        :return: the restored model, initialized and with the loaded variables attached to its saver.
        """
        checkpoint_saver = Saver(JL_BASE)
        restored = checkpoint_saver.load(path)
        restored._initialize()
        # Hand the loaded variable values over to the model's own saver.
        restored.saver.variables = checkpoint_saver.variables
        # Clear the TensorFlow default graph once loading has finished.
        tf.reset_default_graph()
        return restored
예제 #4
0
파일: base.py 프로젝트: RossSong/finetune
    def load(cls, path, **kwargs):
        """
        Restore a fine-tuned model that was previously written with :meth:`save`.

        :param path: folder path (string) holding the .pkl file and tf.Saver() checkpoint files.
        :param **kwargs: config entries that override the values stored with the model.
        :return: the restored model, initialized and with the loaded variables attached to its saver.
        """
        assert_valid_config(**kwargs)
        download_data_if_required()
        checkpoint_saver = Saver()
        restored = checkpoint_saver.load(path)
        restored.config.update(kwargs)
        # Point the saver at the base model weights, using the (possibly overridden) config.
        checkpoint_saver.set_fallback(restored.config.base_model_path)
        restored._initialize()
        restored.saver.variables = checkpoint_saver.variables
        return restored
예제 #5
0
    def load(path, *args, **kwargs):
        """
        Load a saved fine-tuned model from disk.  Path provided should be a folder which contains .pkl and tf.Saver() files

        :param path: string path name to load model from.  Same value as previously provided to :meth:`save`. Must be a folder.
        :param **kwargs: key-value pairs of config items to override.
        :raises FinetuneError: if the method is invoked on a model instance instead of on the class.
        :return: the loaded model, with config overrides applied and weights attached.
        """
        # When invoked as `instance.load(...)`, `path` is bound to the instance and the
        # real path (if one was given) is shifted into `args`.
        if not isinstance(path, str):
            instance = path
            # Guard against `instance.load()` with no arguments, which would otherwise
            # raise IndexError here and hide the intended error.
            suggested_path = args[0] if args else "<path>"
            raise FinetuneError(
                "The .load() method can only be called on the class, not on an instance. Try `{}.load(\"{}\")` instead.".format(
                    instance.__class__.__name__, suggested_path
                )
            )

        assert_valid_config(**kwargs)
        saver = Saver()
        model = saver.load(path)
        model.config.update(kwargs)
        # Download after applying overrides so an overridden base_model is what gets fetched.
        download_data_if_required(model.config.base_model)
        saver.set_fallback(model.config.base_model_path)
        model._initialize()
        model.saver.variables = saver.variables
        return model