示例#1
0
文件: base.py 项目: tc-wolf/finetune
    def load(path, *args, **kwargs):
        """
        Load a saved fine-tuned model from disk.  Path provided should be a folder which contains .pkl and tf.Saver() files

        :param path: string path name to load model from.  Same value as previously provided to :meth:`save`. Must be a folder.
            Objects exposing a ``write`` attribute are also passed through to the saver.
        :param **kwargs: key-value pairs of config items to override.
        :raises FinetuneError: if ``path`` is neither a string nor an object with a ``write`` attribute.
        :returns: the restored model, flagged as trained.
        """
        if not isinstance(path, str) and not hasattr(path, "write"):
            # Per the error message below, this guard targets `.load()` being invoked
            # on an instance: the instance then arrives as `path` and the real path is
            # the first extra positional argument.
            instance = path
            # `args` is empty when a non-string path was passed directly; use a
            # placeholder so the caller gets FinetuneError instead of an IndexError.
            intended_path = args[0] if args else "<path>"
            raise FinetuneError(
                'The .load() method can only be called on the class, not on an instance. Try `{}.load("{}") instead.'
                .format(instance.__class__.__name__, intended_path))

        assert_valid_config(**kwargs)

        saver = Saver()
        model = saver.load(path)

        # Backwards compatibility
        # Ensure old models get new default settings
        for setting, default in get_default_config().items():
            if not hasattr(model.config, setting):
                if setting == "add_eos_bos_to_chunk":
                    # Models saved before this setting existed are pinned to False
                    # rather than to the current default.
                    model.config.add_eos_bos_to_chunk = False
                else:
                    model.config.update({setting: default})

        # Caller overrides are applied last so they win over restored and default values.
        model.config.update(kwargs)
        model.input_pipeline.config = model.config
        download_data_if_required(model.config.base_model)
        saver.set_fallback(model.config.base_model_path)
        model._initialize()
        # Hand the variables read by the temporary saver over to the model's own saver.
        model.saver.variables = saver.variables
        model._trained = True
        return model
示例#2
0
文件: base.py 项目: tc-wolf/finetune
    def __init__(self, **kwargs):
        """
        Build a model from the class-level defaults overlaid with caller overrides.

        For a full list of configuration options, see `finetune.config`.

        :param **kwargs: key-value pairs of config items to override.
        """
        self_ref = weakref.ref(self)

        def _finalize_at_exit():
            # The weak reference keeps the atexit hook from pinning the instance;
            # only finalize if it is still alive when the hook runs.
            target = self_ref()
            if target is None:
                return
            BaseModel.__del__(target)

        atexit.register(_finalize_at_exit)

        # Work on a copy so the overrides are never written into `self.defaults`.
        settings = deepcopy(self.defaults)
        settings.update(kwargs)
        self.config = get_config(**settings)

        self.resolved_gpus = None
        self.validate_config()
        download_data_if_required(self.config.base_model)
        self.input_pipeline = self._get_input_pipeline()
        self._trained = False
        self._initialize()

        if not self.config.debugging_logs:
            return
        # Debug logging requested: turn TensorFlow's native and Python logging fully on.
        os.environ["TF_CPP_MIN_LOG_LEVEL"] = "0"
        tf_logging.set_verbosity(tf_logging.DEBUG)
示例#3
0
    def __init__(self, **kwargs):
        """
        Build a model from the given configuration overrides.

        For a full list of configuration options, see `finetune.config`.

        :param **kwargs: key-value pairs of config items to override.
        :raises FinetuneError: if ``default_context`` is set but is not a dictionary.
        """
        weak_self = weakref.ref(self)

        def cleanup():
            # The weak reference keeps the atexit hook from pinning the instance;
            # only finalize if it is still alive when the hook runs.
            strong_self = weak_self()
            if strong_self is not None:
                BaseModel.__del__(strong_self)

        atexit.register(cleanup)

        self.config = get_config(**kwargs)
        default_context = self.config.default_context
        # isinstance (rather than an exact type match) lets dict subclasses such as
        # OrderedDict through, since they satisfy the "dictionary" requirement too.
        if default_context is not None and not isinstance(default_context, dict):
            raise FinetuneError(
                "Invalid default given: Need a dictionary of auxiliary info fields and default values."
            )
        # Auxiliary info is enabled exactly when a default context was supplied.
        self.config.use_auxiliary_info = default_context is not None

        self.resolved_gpus = None
        self.validate_config()
        download_data_if_required(self.config.base_model)
        self.input_pipeline = self._get_input_pipeline()
        self._trained = False
        self._initialize()
        if self.config.debugging_logs:
            # Debug logging requested: turn TensorFlow's native and Python logging fully on.
            os.environ["TF_CPP_MIN_LOG_LEVEL"] = "0"
            tf_logging.set_verbosity(tf_logging.DEBUG)
示例#4
0
文件: base.py 项目: RossSong/finetune
    def load(cls, path, **kwargs):
        """
        Restore a previously saved fine-tuned model from a folder on disk.

        :param path: string path name to load model from.  Same value as previously provided to :meth:`save`.
            Must be a folder which contains .pkl and tf.Saver() files.
        :param **kwargs: key-value pairs of config items to override.
        :returns: the loaded model with the overrides applied to its config.
        """
        assert_valid_config(**kwargs)
        download_data_if_required()

        checkpoint = Saver()
        restored = checkpoint.load(path)

        # Overrides go in before the fallback path is read, so an overridden
        # `base_model_path` is the one handed to the saver.
        restored.config.update(kwargs)
        checkpoint.set_fallback(restored.config.base_model_path)

        restored._initialize()
        # Hand the variables read by the temporary saver over to the model's own saver.
        restored.saver.variables = checkpoint.variables
        return restored
示例#5
0
    def load(path, *args, **kwargs):
        """
        Load a saved fine-tuned model from disk.  Path provided should be a folder which contains .pkl and tf.Saver() files

        :param path: string path name to load model from.  Same value as previously provided to :meth:`save`. Must be a folder.
        :param **kwargs: key-value pairs of config items to override.
        :raises FinetuneError: if ``path`` is not a string.
        :returns: the restored model with the overrides applied to its config.
        """
        if not isinstance(path, str):
            # Per the error message below, this guard targets `.load()` being invoked
            # on an instance: the instance then arrives as `path` and the real path is
            # the first extra positional argument.
            instance = path
            # `args` is empty when a non-string path was passed directly; use a
            # placeholder so the caller gets FinetuneError instead of an IndexError.
            intended_path = args[0] if args else "<path>"
            raise FinetuneError(
                "The .load() method can only be called on the class, not on an instance. Try `{}.load(\"{}\") instead.".format(
                    instance.__class__.__name__, intended_path
                )
            )

        assert_valid_config(**kwargs)
        saver = Saver()
        model = saver.load(path)
        # Caller overrides are applied before the base model is resolved, so an
        # overridden `base_model` is the one that gets downloaded.
        model.config.update(kwargs)
        download_data_if_required(model.config.base_model)
        saver.set_fallback(model.config.base_model_path)
        model._initialize()
        # Hand the variables read by the temporary saver over to the model's own saver.
        model.saver.variables = saver.variables
        return model